1import threading
2import time
3from sqlalchemy import pool, select, event
4import sqlalchemy as tsa
5from sqlalchemy import testing
6from sqlalchemy.testing.util import gc_collect, lazy_gc
7from sqlalchemy.testing import eq_, assert_raises, is_not_, is_
8from sqlalchemy.testing.engines import testing_engine
9from sqlalchemy.testing import fixtures
10import random
11from sqlalchemy.testing.mock import Mock, call, patch, ANY
12import weakref
13import collections
14
# maximum number of seconds to wait on Thread.join() in concurrency tests
join_timeout = 10
16
17
def MockDBAPI():  # noqa
    """Return a Mock standing in for a DBAPI module.

    The mock's ``connect()`` yields a fresh connection mock whose
    ``cursor()`` calls each produce a new Mock.  ``shutdown(value)``
    toggles the module between a working state and one where
    ``connect()`` raises.
    """
    def make_cursor():
        return Mock()

    def make_connection(*arg, **kw):
        return Mock(cursor=Mock(side_effect=make_cursor))

    def shutdown(value):
        # swap connect()'s side effect between failure and success
        if value:
            db.connect = Mock(side_effect=Exception("connect failed"))
        else:
            db.connect = Mock(side_effect=make_connection)
        db.is_shutdown = value

    db = Mock(
        connect=Mock(side_effect=make_connection),
        is_shutdown=False,
        shutdown=shutdown,
    )
    return db
37
38
class PoolTestBase(fixtures.TestBase):
    """Base for pool tests: supplies mock-backed QueuePool fixtures and
    closes any connections registered via :meth:`_with_teardown`."""

    def setup(self):
        pool.clear_managers()
        self._teardown_conns = []

    def teardown(self):
        # close registered connections that are still alive; the list
        # holds weak references so gc-based tests are unaffected
        for ref in self._teardown_conns:
            conn = ref()
            if conn:
                conn.close()

    @classmethod
    def teardown_class(cls):
        pool.clear_managers()

    def _with_teardown(self, connection):
        # track via weakref only, so the connection can still be
        # collected by tests that rely on dereference semantics
        self._teardown_conns.append(weakref.ref(connection))
        return connection

    def _queuepool_fixture(self, **kw):
        # same as _queuepool_dbapi_fixture but discards the mock DBAPI
        return self._queuepool_dbapi_fixture(**kw)[1]

    def _queuepool_dbapi_fixture(self, **kw):
        dbapi = MockDBAPI()
        queue_pool = pool.QueuePool(
            creator=lambda: dbapi.connect('foo.db'),
            **kw)
        return dbapi, queue_pool
67
68
class PoolTest(PoolTestBase):
    """Core pool behavior: the (legacy) pool manager, threadlocal
    checkout semantics, the per-checkout ``.info`` / ``.record_info``
    dictionaries, and low-level ``_ConnectionRecord`` handling."""

    def test_manager(self):
        """A threadlocal manager returns the same connection for the
        same connect() argument signature within a thread."""
        manager = pool.manage(MockDBAPI(), use_threadlocal=True)

        c1 = manager.connect('foo.db')
        c2 = manager.connect('foo.db')
        c3 = manager.connect('bar.db')
        c4 = manager.connect("foo.db", bar="bat")
        c5 = manager.connect("foo.db", bar="hoho")
        c6 = manager.connect("foo.db", bar="bat")

        assert c1.cursor() is not None
        assert c1 is c2
        assert c1 is not c3
        assert c4 is c6
        assert c4 is not c5

    def test_manager_with_key(self):
        """sa_pool_key overrides the connect() arguments as the pooling
        key; distinct args can share a pool and identical args can be
        separated."""

        dbapi = MockDBAPI()
        manager = pool.manage(dbapi, use_threadlocal=True)

        c1 = manager.connect('foo.db', sa_pool_key="a")
        c2 = manager.connect('foo.db', sa_pool_key="b")
        c3 = manager.connect('bar.db', sa_pool_key="a")

        assert c1.cursor() is not None
        assert c1 is not c2
        assert c1 is c3

        # only two real DBAPI connects occur (keys "a" and "b"); the
        # sa_pool_key argument itself is not forwarded to the DBAPI
        eq_(
            dbapi.connect.mock_calls,
            [
                call("foo.db"),
                call("foo.db"),
            ]
        )

    def test_bad_args(self):
        """A None connect() argument is tolerated by the manager."""
        manager = pool.manage(MockDBAPI())
        manager.connect(None)

    def test_non_thread_local_manager(self):
        """Without threadlocal behavior each connect() is distinct."""
        manager = pool.manage(MockDBAPI(), use_threadlocal=False)

        connection = manager.connect('foo.db')
        connection2 = manager.connect('foo.db')

        self.assert_(connection.cursor() is not None)
        self.assert_(connection is not connection2)

    @testing.fails_on('+pyodbc',
                      "pyodbc cursor doesn't implement tuple __eq__")
    @testing.fails_on("+pg8000", "returns [1], not (1,)")
    def test_cursor_iterable(self):
        """Raw DBAPI cursors obtained from the pool are iterable."""
        conn = testing.db.raw_connection()
        cursor = conn.cursor()
        cursor.execute(str(select([1], bind=testing.db)))
        expected = [(1, )]
        for row in cursor:
            eq_(row, expected.pop(0))

    def test_no_connect_on_recreate(self):
        """recreate() must never invoke the creator itself, for every
        pool implementation."""
        def creator():
            raise Exception("no creates allowed")

        for cls in (pool.SingletonThreadPool, pool.StaticPool,
                    pool.QueuePool, pool.NullPool, pool.AssertionPool):
            p = cls(creator=creator)
            p.dispose()
            p2 = p.recreate()
            assert p2.__class__ is cls

            mock_dbapi = MockDBAPI()
            p = cls(creator=mock_dbapi.connect)
            conn = p.connect()
            conn.close()
            mock_dbapi.connect.side_effect = Exception("error!")
            p.dispose()
            p.recreate()

    def test_threadlocal_del(self):
        """Threadlocal semantics with checkin via dereference/gc."""
        self._do_testthreadlocal(useclose=False)

    def test_threadlocal_close(self):
        """Threadlocal semantics with checkin via explicit close()."""
        self._do_testthreadlocal(useclose=True)

    def _do_testthreadlocal(self, useclose=False):
        """With use_threadlocal=True, repeated connect() within a thread
        returns the same checked-out connection, while
        unique_connection() bypasses the threadlocal.  Runs against both
        QueuePool and SingletonThreadPool."""
        dbapi = MockDBAPI()
        for p in pool.QueuePool(creator=dbapi.connect,
                                pool_size=3, max_overflow=-1,
                                use_threadlocal=True), \
            pool.SingletonThreadPool(
                creator=dbapi.connect,
                use_threadlocal=True):
            c1 = p.connect()
            c2 = p.connect()
            self.assert_(c1 is c2)
            c3 = p.unique_connection()
            self.assert_(c3 is not c1)
            if useclose:
                c2.close()
            else:
                c2 = None
            c2 = p.connect()
            self.assert_(c1 is c2)
            self.assert_(c3 is not c1)
            if useclose:
                c2.close()
            else:
                c2 = None
                lazy_gc()
            if useclose:
                c1 = p.connect()
                c2 = p.connect()
                c3 = p.connect()
                c3.close()
                c2.close()
                self.assert_(c1.connection is not None)
                c1.close()
            c1 = c2 = c3 = None

            # extra tests with QueuePool to ensure connections get
            # __del__()ed when dereferenced

            if isinstance(p, pool.QueuePool):
                lazy_gc()
                self.assert_(p.checkedout() == 0)
                c1 = p.connect()
                c2 = p.connect()
                if useclose:
                    c2.close()
                    c1.close()
                else:
                    c2 = None
                    c1 = None
                    lazy_gc()
                self.assert_(p.checkedout() == 0)

    def test_info(self):
        """.info persists across checkouts of the same record, is
        discarded on invalidate, and follows the connection on
        detach."""
        p = self._queuepool_fixture(pool_size=1, max_overflow=0)

        c = p.connect()
        self.assert_(not c.info)
        self.assert_(c.info is c._connection_record.info)

        c.info['foo'] = 'bar'
        c.close()
        del c

        # pool_size=1: the next checkout reuses the same record
        c = p.connect()
        self.assert_('foo' in c.info)

        c.invalidate()
        c = p.connect()
        self.assert_('foo' not in c.info)

        c.info['foo2'] = 'bar2'
        c.detach()
        self.assert_('foo2' in c.info)

        c2 = p.connect()
        is_not_(c.connection, c2.connection)
        assert not c2.info
        assert 'foo2' in c.info

    def test_rec_info(self):
        """.record_info belongs to the record, so it survives
        invalidation but is severed from the connection on detach."""
        p = self._queuepool_fixture(pool_size=1, max_overflow=0)

        c = p.connect()
        self.assert_(not c.record_info)
        self.assert_(c.record_info is c._connection_record.record_info)

        c.record_info['foo'] = 'bar'
        c.close()
        del c

        c = p.connect()
        self.assert_('foo' in c.record_info)

        c.invalidate()
        c = p.connect()
        self.assert_('foo' in c.record_info)

        c.record_info['foo2'] = 'bar2'
        c.detach()
        is_(c.record_info, None)
        is_(c._connection_record, None)

        c2 = p.connect()

        assert c2.record_info
        assert 'foo2' in c2.record_info

    def test_rec_unconnected(self):
        # test production of a _ConnectionRecord with an
        # initially unconnected state.

        dbapi = MockDBAPI()
        p1 = pool.Pool(
            creator=lambda: dbapi.connect('foo.db')
        )

        r1 = pool._ConnectionRecord(p1, connect=False)

        assert not r1.connection
        c1 = r1.get_connection()
        is_(c1, r1.connection)

    def test_rec_close_reopen(self):
        # test that _ConnectionRecord.close() allows
        # the record to be reusable
        dbapi = MockDBAPI()
        p1 = pool.Pool(
            creator=lambda: dbapi.connect('foo.db')
        )

        r1 = pool._ConnectionRecord(p1)

        c1 = r1.connection
        c2 = r1.get_connection()
        is_(c1, c2)

        r1.close()

        # close() disposes the underlying DBAPI connection...
        assert not r1.connection
        eq_(
            c1.mock_calls,
            [call.close()]
        )

        # ...but the record can produce a fresh one afterwards
        c2 = r1.get_connection()

        is_not_(c1, c2)
        is_(c2, r1.connection)

        eq_(
            c2.mock_calls,
            []
        )
309
310
class PoolDialectTest(PoolTestBase):
    """Verify each pool class routes rollback/close of raw DBAPI
    connections through its dialect's do_rollback/do_close hooks in
    the expected order."""

    def _dialect(self):
        # a minimal dialect recording every hook invocation
        recorded = []

        class PoolDialect(object):
            def do_rollback(self, dbapi_connection):
                recorded.append('R')
                dbapi_connection.rollback()

            def do_commit(self, dbapi_connection):
                recorded.append('C')
                dbapi_connection.commit()

            def do_close(self, dbapi_connection):
                recorded.append('CL')
                dbapi_connection.close()

        return PoolDialect(), recorded

    def _do_test(self, pool_cls, assertion):
        dialect, canary = self._dialect()
        dbapi = MockDBAPI()

        target = pool_cls(creator=dbapi.connect)
        target._dialect = dialect

        # a checkout/checkin cycle, a dispose + recreate, then a
        # second checkout/checkin cycle
        target.connect().close()
        target.dispose()
        target.recreate()
        target.connect().close()

        eq_(canary, assertion)

    def test_queue_pool(self):
        self._do_test(pool.QueuePool, ['R', 'CL', 'R'])

    def test_assertion_pool(self):
        self._do_test(pool.AssertionPool, ['R', 'CL', 'R'])

    def test_singleton_pool(self):
        self._do_test(pool.SingletonThreadPool, ['R', 'CL', 'R'])

    def test_null_pool(self):
        self._do_test(pool.NullPool, ['R', 'CL', 'R', 'CL'])

    def test_static_pool(self):
        self._do_test(pool.StaticPool, ['R', 'R'])
357
358
class PoolEventsTest(PoolTestBase):
    """Tests for pool-level events registered via event.listen():
    connect, first_connect, checkout, checkin, reset, invalidate,
    soft_invalidate, close, detach and close_detached — including
    propagation across pool.recreate() and class-level listeners."""

    # each fixture returns (pool, canary); the canary records
    # invocations of a single listened-for event

    def _first_connect_event_fixture(self):
        p = self._queuepool_fixture()
        canary = []

        def first_connect(*arg, **kw):
            canary.append('first_connect')

        event.listen(p, 'first_connect', first_connect)

        return p, canary

    def _connect_event_fixture(self):
        p = self._queuepool_fixture()
        canary = []

        def connect(*arg, **kw):
            canary.append('connect')

        event.listen(p, 'connect', connect)

        return p, canary

    def _checkout_event_fixture(self):
        p = self._queuepool_fixture()
        canary = []

        def checkout(*arg, **kw):
            canary.append('checkout')
        event.listen(p, 'checkout', checkout)

        return p, canary

    def _checkin_event_fixture(self):
        p = self._queuepool_fixture()
        canary = []

        def checkin(*arg, **kw):
            canary.append('checkin')
        event.listen(p, 'checkin', checkin)

        return p, canary

    def _reset_event_fixture(self):
        p = self._queuepool_fixture()
        canary = []

        def reset(*arg, **kw):
            canary.append('reset')
        event.listen(p, 'reset', reset)

        return p, canary

    def _invalidate_event_fixture(self):
        p = self._queuepool_fixture()
        canary = Mock()
        event.listen(p, 'invalidate', canary)

        return p, canary

    def _soft_invalidate_event_fixture(self):
        p = self._queuepool_fixture()
        canary = Mock()
        event.listen(p, 'soft_invalidate', canary)

        return p, canary

    def _close_event_fixture(self):
        p = self._queuepool_fixture()
        canary = Mock()
        event.listen(p, 'close', canary)

        return p, canary

    def _detach_event_fixture(self):
        p = self._queuepool_fixture()
        canary = Mock()
        event.listen(p, 'detach', canary)

        return p, canary

    def _close_detached_event_fixture(self):
        p = self._queuepool_fixture()
        canary = Mock()
        event.listen(p, 'close_detached', canary)

        return p, canary

    def test_close(self):
        """'close' fires when the pool disposes a connection, not on
        ordinary checkin."""
        p, canary = self._close_event_fixture()

        c1 = p.connect()

        connection = c1.connection
        rec = c1._connection_record

        c1.close()

        eq_(canary.mock_calls, [])

        p.dispose()
        eq_(canary.mock_calls, [call(connection, rec)])

    def test_detach(self):
        """'detach' fires with the connection and its (former)
        record."""
        p, canary = self._detach_event_fixture()

        c1 = p.connect()

        connection = c1.connection
        rec = c1._connection_record

        c1.detach()

        eq_(canary.mock_calls, [call(connection, rec)])

    def test_detach_close(self):
        """closing a detached connection fires 'close_detached' with
        just the DBAPI connection."""
        p, canary = self._close_detached_event_fixture()

        c1 = p.connect()

        connection = c1.connection

        c1.detach()

        c1.close()
        eq_(canary.mock_calls, [call(connection)])

    def test_first_connect_event(self):
        p, canary = self._first_connect_event_fixture()

        p.connect()
        eq_(canary, ['first_connect'])

    def test_first_connect_event_fires_once(self):
        p, canary = self._first_connect_event_fixture()

        p.connect()
        p.connect()

        eq_(canary, ['first_connect'])

    def test_first_connect_on_previously_recreated(self):
        # the recreated pool gets its own 'first_connect' firing
        p, canary = self._first_connect_event_fixture()

        p2 = p.recreate()
        p.connect()
        p2.connect()

        eq_(canary, ['first_connect', 'first_connect'])

    def test_first_connect_on_subsequently_recreated(self):
        p, canary = self._first_connect_event_fixture()

        p.connect()
        p2 = p.recreate()
        p2.connect()

        eq_(canary, ['first_connect', 'first_connect'])

    def test_connect_event(self):
        p, canary = self._connect_event_fixture()

        p.connect()
        eq_(canary, ['connect'])

    def test_connect_event_fires_subsequent(self):
        p, canary = self._connect_event_fixture()

        c1 = p.connect()  # noqa
        c2 = p.connect()  # noqa

        eq_(canary, ['connect', 'connect'])

    def test_connect_on_previously_recreated(self):
        p, canary = self._connect_event_fixture()

        p2 = p.recreate()

        p.connect()
        p2.connect()

        eq_(canary, ['connect', 'connect'])

    def test_connect_on_subsequently_recreated(self):
        p, canary = self._connect_event_fixture()

        p.connect()
        p2 = p.recreate()
        p2.connect()

        eq_(canary, ['connect', 'connect'])

    def test_checkout_event(self):
        p, canary = self._checkout_event_fixture()

        p.connect()
        eq_(canary, ['checkout'])

    def test_checkout_event_fires_subsequent(self):
        p, canary = self._checkout_event_fixture()

        p.connect()
        p.connect()
        eq_(canary, ['checkout', 'checkout'])

    def test_checkout_event_on_subsequently_recreated(self):
        p, canary = self._checkout_event_fixture()

        p.connect()
        p2 = p.recreate()
        p2.connect()

        eq_(canary, ['checkout', 'checkout'])

    def test_checkin_event(self):
        p, canary = self._checkin_event_fixture()

        c1 = p.connect()
        eq_(canary, [])
        c1.close()
        eq_(canary, ['checkin'])

    def test_reset_event(self):
        p, canary = self._reset_event_fixture()

        c1 = p.connect()
        eq_(canary, [])
        c1.close()
        eq_(canary, ['reset'])

    def test_soft_invalidate_event_no_exception(self):
        p, canary = self._soft_invalidate_event_fixture()

        c1 = p.connect()
        c1.close()
        assert not canary.called
        c1 = p.connect()
        dbapi_con = c1.connection
        c1.invalidate(soft=True)
        # listener positional args: (dbapi_connection, record, exception)
        assert canary.call_args_list[0][0][0] is dbapi_con
        assert canary.call_args_list[0][0][2] is None

    def test_soft_invalidate_event_exception(self):
        p, canary = self._soft_invalidate_event_fixture()

        c1 = p.connect()
        c1.close()
        assert not canary.called
        c1 = p.connect()
        dbapi_con = c1.connection
        exc = Exception("hi")
        c1.invalidate(exc, soft=True)
        assert canary.call_args_list[0][0][0] is dbapi_con
        assert canary.call_args_list[0][0][2] is exc

    def test_invalidate_event_no_exception(self):
        p, canary = self._invalidate_event_fixture()

        c1 = p.connect()
        c1.close()
        assert not canary.called
        c1 = p.connect()
        dbapi_con = c1.connection
        c1.invalidate()
        # listener positional args: (dbapi_connection, record, exception)
        assert canary.call_args_list[0][0][0] is dbapi_con
        assert canary.call_args_list[0][0][2] is None

    def test_invalidate_event_exception(self):
        p, canary = self._invalidate_event_fixture()

        c1 = p.connect()
        c1.close()
        assert not canary.called
        c1 = p.connect()
        dbapi_con = c1.connection
        exc = Exception("hi")
        c1.invalidate(exc)
        assert canary.call_args_list[0][0][0] is dbapi_con
        assert canary.call_args_list[0][0][2] is exc

    def test_checkin_event_gc(self):
        """checkin also fires when the connection is dereferenced and
        collected rather than closed."""
        p, canary = self._checkin_event_fixture()

        c1 = p.connect()
        eq_(canary, [])
        del c1
        lazy_gc()
        eq_(canary, ['checkin'])

    def test_checkin_event_on_subsequently_recreated(self):
        p, canary = self._checkin_event_fixture()

        c1 = p.connect()
        p2 = p.recreate()
        c2 = p2.connect()

        eq_(canary, [])

        c1.close()
        eq_(canary, ['checkin'])

        c2.close()
        eq_(canary, ['checkin', 'checkin'])

    def test_listen_targets_scope(self):
        """listeners registered against the Pool class, a pool
        instance, an engine, and the Engine class all fire for a
        single engine-level connect."""
        canary = []

        def listen_one(*args):
            canary.append("listen_one")

        def listen_two(*args):
            canary.append("listen_two")

        def listen_three(*args):
            canary.append("listen_three")

        def listen_four(*args):
            canary.append("listen_four")

        engine = testing_engine(testing.db.url)
        event.listen(pool.Pool, 'connect', listen_one)
        event.listen(engine.pool, 'connect', listen_two)
        event.listen(engine, 'connect', listen_three)
        event.listen(engine.__class__, 'connect', listen_four)

        engine.execute(select([1])).close()
        eq_(
            canary,
            ["listen_one", "listen_four", "listen_two", "listen_three"]
        )

    def test_listen_targets_per_subclass(self):
        """test that listen() called on a subclass remains specific to
        that subclass."""

        canary = []

        def listen_one(*args):
            canary.append("listen_one")

        def listen_two(*args):
            canary.append("listen_two")

        def listen_three(*args):
            canary.append("listen_three")

        event.listen(pool.Pool, 'connect', listen_one)
        event.listen(pool.QueuePool, 'connect', listen_two)
        event.listen(pool.SingletonThreadPool, 'connect', listen_three)

        p1 = pool.QueuePool(creator=MockDBAPI().connect)
        p2 = pool.SingletonThreadPool(creator=MockDBAPI().connect)

        assert listen_one in p1.dispatch.connect
        assert listen_two in p1.dispatch.connect
        assert listen_three not in p1.dispatch.connect
        assert listen_one in p2.dispatch.connect
        assert listen_two not in p2.dispatch.connect
        assert listen_three in p2.dispatch.connect

        p1.connect()
        eq_(canary, ["listen_one", "listen_two"])
        p2.connect()
        eq_(canary, ["listen_one", "listen_two", "listen_one", "listen_three"])

    def teardown(self):
        # TODO: need to get remove() functionality
        # going
        pool.Pool.dispatch._clear()
728
729
class PoolFirstConnectSyncTest(PoolTestBase):
    # test [ticket:2964]

    @testing.requires.timing_intensive
    def test_sync(self):
        """Concurrent first checkouts must block until the slow
        'first_connect' listener completes, so no 'connect' event can
        precede it."""
        pool = self._queuepool_fixture(pool_size=3, max_overflow=0)

        evt = Mock()

        @event.listens_for(pool, 'first_connect')
        def slow_first_connect(dbapi_con, rec):
            time.sleep(1)
            evt.first_connect()

        @event.listens_for(pool, 'connect')
        def on_connect(dbapi_con, rec):
            evt.connect()

        def checkout():
            for j in range(2):
                c1 = pool.connect()
                time.sleep(.02)
                c1.close()
                time.sleep(.02)

        threads = []
        for i in range(5):
            th = threading.Thread(target=checkout)
            th.start()
            threads.append(th)
        for th in threads:
            th.join(join_timeout)

        # first_connect fires exactly once, before any of the three
        # pooled connections' connect events
        eq_(
            evt.mock_calls,
            [
                call.first_connect(),
                call.connect(),
                call.connect(),
                call.connect()]
        )
771
772
class DeprecatedPoolListenerTest(PoolTestBase):
    """Tests for the legacy PoolListener interface (pre-event API),
    exercised via Pool.add_listener() and the listeners= argument."""

    @testing.requires.predictable_gc
    @testing.uses_deprecated(r".*Use event.listen")
    def test_listeners(self):

        class InstrumentingListener(object):
            # rebinds whichever hook methods a subclass declares to the
            # recording inst_* versions, then tracks which connections
            # passed through each hook
            def __init__(self):
                if hasattr(self, 'connect'):
                    self.connect = self.inst_connect
                if hasattr(self, 'first_connect'):
                    self.first_connect = self.inst_first_connect
                if hasattr(self, 'checkout'):
                    self.checkout = self.inst_checkout
                if hasattr(self, 'checkin'):
                    self.checkin = self.inst_checkin
                self.clear()

            def clear(self):
                self.connected = []
                self.first_connected = []
                self.checked_out = []
                self.checked_in = []

            def assert_total(self, conn, fconn, cout, cin):
                # assert counts of each hook firing so far
                eq_(len(self.connected), conn)
                eq_(len(self.first_connected), fconn)
                eq_(len(self.checked_out), cout)
                eq_(len(self.checked_in), cin)

            def assert_in(
                    self, item, in_conn, in_fconn,
                    in_cout, in_cin):
                # assert membership of a specific connection per hook
                eq_((item in self.connected), in_conn)
                eq_((item in self.first_connected), in_fconn)
                eq_((item in self.checked_out), in_cout)
                eq_((item in self.checked_in), in_cin)

            def inst_connect(self, con, record):
                print("connect(%s, %s)" % (con, record))
                assert con is not None
                assert record is not None
                self.connected.append(con)

            def inst_first_connect(self, con, record):
                print("first_connect(%s, %s)" % (con, record))
                assert con is not None
                assert record is not None
                self.first_connected.append(con)

            def inst_checkout(self, con, record, proxy):
                print("checkout(%s, %s, %s)" % (con, record, proxy))
                assert con is not None
                assert record is not None
                assert proxy is not None
                self.checked_out.append(con)

            def inst_checkin(self, con, record):
                print("checkin(%s, %s)" % (con, record))
                # con can be None if invalidated
                assert record is not None
                self.checked_in.append(con)

        class ListenAll(tsa.interfaces.PoolListener, InstrumentingListener):
            pass

        class ListenConnect(InstrumentingListener):
            def connect(self, con, record):
                pass

        class ListenFirstConnect(InstrumentingListener):
            def first_connect(self, con, record):
                pass

        class ListenCheckOut(InstrumentingListener):
            def checkout(self, con, record, proxy, num):
                pass

        class ListenCheckIn(InstrumentingListener):
            def checkin(self, con, record):
                pass

        def assert_listeners(p, total, conn, fconn, cout, cin):
            # listener counts must survive recreate() as well
            for instance in (p, p.recreate()):
                self.assert_(len(instance.dispatch.connect) == conn)
                self.assert_(len(instance.dispatch.first_connect) == fconn)
                self.assert_(len(instance.dispatch.checkout) == cout)
                self.assert_(len(instance.dispatch.checkin) == cin)

        p = self._queuepool_fixture()
        assert_listeners(p, 0, 0, 0, 0, 0)

        p.add_listener(ListenAll())
        assert_listeners(p, 1, 1, 1, 1, 1)

        p.add_listener(ListenConnect())
        assert_listeners(p, 2, 2, 1, 1, 1)

        p.add_listener(ListenFirstConnect())
        assert_listeners(p, 3, 2, 2, 1, 1)

        p.add_listener(ListenCheckOut())
        assert_listeners(p, 4, 2, 2, 2, 1)

        p.add_listener(ListenCheckIn())
        assert_listeners(p, 5, 2, 2, 2, 2)
        del p

        snoop = ListenAll()
        p = self._queuepool_fixture(listeners=[snoop])
        assert_listeners(p, 1, 1, 1, 1, 1)

        c = p.connect()
        snoop.assert_total(1, 1, 1, 0)
        cc = c.connection
        snoop.assert_in(cc, True, True, True, False)
        c.close()
        snoop.assert_in(cc, True, True, True, True)
        del c, cc

        snoop.clear()

        # this one depends on immediate gc
        c = p.connect()
        cc = c.connection
        snoop.assert_in(cc, False, False, True, False)
        snoop.assert_total(0, 0, 1, 0)
        del c, cc
        lazy_gc()
        snoop.assert_total(0, 0, 1, 1)

        p.dispose()
        snoop.clear()

        c = p.connect()
        c.close()
        c = p.connect()
        snoop.assert_total(1, 0, 2, 1)
        c.close()
        snoop.assert_total(1, 0, 2, 2)

        # invalidation
        p.dispose()
        snoop.clear()

        c = p.connect()
        snoop.assert_total(1, 0, 1, 0)
        c.invalidate()
        snoop.assert_total(1, 0, 1, 1)
        c.close()
        snoop.assert_total(1, 0, 1, 1)
        del c
        lazy_gc()
        snoop.assert_total(1, 0, 1, 1)
        c = p.connect()
        snoop.assert_total(2, 0, 2, 1)
        c.close()
        del c
        lazy_gc()
        snoop.assert_total(2, 0, 2, 2)

        # detached
        p.dispose()
        snoop.clear()

        c = p.connect()
        snoop.assert_total(1, 0, 1, 0)
        c.detach()
        snoop.assert_total(1, 0, 1, 0)
        c.close()
        del c
        snoop.assert_total(1, 0, 1, 0)
        c = p.connect()
        snoop.assert_total(2, 0, 2, 0)
        c.close()
        del c
        snoop.assert_total(2, 0, 2, 1)

        # recreated
        p = p.recreate()
        snoop.clear()

        c = p.connect()
        snoop.assert_total(1, 1, 1, 0)
        c.close()
        snoop.assert_total(1, 1, 1, 1)
        c = p.connect()
        snoop.assert_total(1, 1, 2, 1)
        c.close()
        snoop.assert_total(1, 1, 2, 2)

    @testing.uses_deprecated(r".*Use event.listen")
    def test_listeners_callables(self):
        """plain dicts of callables are accepted as listeners too."""
        def connect(dbapi_con, con_record):
            counts[0] += 1

        def checkout(dbapi_con, con_record, con_proxy):
            counts[1] += 1

        def checkin(dbapi_con, con_record):
            counts[2] += 1

        i_all = dict(connect=connect, checkout=checkout, checkin=checkin)
        i_connect = dict(connect=connect)
        i_checkout = dict(checkout=checkout)
        i_checkin = dict(checkin=checkin)

        for cls in (pool.QueuePool, pool.StaticPool):
            # counts = [connects, checkouts, checkins]
            counts = [0, 0, 0]

            def assert_listeners(p, total, conn, cout, cin):
                for instance in (p, p.recreate()):
                    eq_(len(instance.dispatch.connect), conn)
                    eq_(len(instance.dispatch.checkout), cout)
                    eq_(len(instance.dispatch.checkin), cin)

            p = self._queuepool_fixture()
            assert_listeners(p, 0, 0, 0, 0)

            p.add_listener(i_all)
            assert_listeners(p, 1, 1, 1, 1)

            p.add_listener(i_connect)
            assert_listeners(p, 2, 1, 1, 1)

            p.add_listener(i_checkout)
            assert_listeners(p, 3, 1, 1, 1)

            p.add_listener(i_checkin)
            assert_listeners(p, 4, 1, 1, 1)
            del p

            p = self._queuepool_fixture(listeners=[i_all])
            assert_listeners(p, 1, 1, 1, 1)

            c = p.connect()
            assert counts == [1, 1, 0]
            c.close()
            assert counts == [1, 1, 1]

            c = p.connect()
            assert counts == [1, 2, 1]
            p.add_listener(i_checkin)
            c.close()
            assert counts == [1, 2, 2]
1017
1018
1019class QueuePoolTest(PoolTestBase):
1020
    def test_queuepool_del(self):
        """QueuePool accounting with checkin via dereference + gc."""
        self._do_testqueuepool(useclose=False)
1023
    def test_queuepool_close(self):
        """QueuePool accounting with checkin via explicit close()."""
        self._do_testqueuepool(useclose=True)
1026
    def _do_testqueuepool(self, useclose=False):
        """Exercise checkout/checkin accounting on a pool_size=3
        QueuePool with unlimited overflow, releasing connections either
        via close() (useclose=True) or by dereference + gc."""
        p = self._queuepool_fixture(
            pool_size=3,
            max_overflow=-1)

        def status(pool):
            # (size, checked-in, overflow, checked-out)
            return pool.size(), pool.checkedin(), pool.overflow(), \
                pool.checkedout()

        # overflow starts at -pool_size and grows with each checkout
        c1 = p.connect()
        self.assert_(status(p) == (3, 0, -2, 1))
        c2 = p.connect()
        self.assert_(status(p) == (3, 0, -1, 2))
        c3 = p.connect()
        self.assert_(status(p) == (3, 0, 0, 3))
        c4 = p.connect()
        self.assert_(status(p) == (3, 0, 1, 4))
        c5 = p.connect()
        self.assert_(status(p) == (3, 0, 2, 5))
        c6 = p.connect()
        self.assert_(status(p) == (3, 0, 3, 6))
        if useclose:
            c4.close()
            c3.close()
            c2.close()
        else:
            c4 = c3 = c2 = None
            lazy_gc()
        self.assert_(status(p) == (3, 3, 3, 3))
        if useclose:
            c1.close()
            c5.close()
            c6.close()
        else:
            c1 = c5 = c6 = None
            lazy_gc()
        # overflow connections are discarded once the queue is full
        self.assert_(status(p) == (3, 3, 0, 0))
        c1 = p.connect()
        c2 = p.connect()
        self.assert_(status(p) == (3, 1, 0, 2), status(p))
        if useclose:
            c2.close()
        else:
            c2 = None
            lazy_gc()
        self.assert_(status(p) == (3, 2, 0, 1))
        c1.close()
        lazy_gc()
        assert not pool._refs
1076
    @testing.requires.timing_intensive
    def test_timeout(self):
        # a pool at capacity with max_overflow=0 must raise TimeoutError
        # after roughly ``timeout`` seconds of waiting
        p = self._queuepool_fixture(
            pool_size=3,
            max_overflow=0,
            timeout=2)
        c1 = p.connect()  # noqa
        c2 = p.connect()  # noqa
        c3 = p.connect()  # noqa
        now = time.time()

        assert_raises(
            tsa.exc.TimeoutError,
            p.connect
        )
        # the failed wait should have lasted the full two-second timeout
        assert int(time.time() - now) == 2
1093
    @testing.requires.threading_with_mock
    @testing.requires.timing_intensive
    def test_timeout_race(self):
        # test a race condition where the initial connecting threads all race
        # to queue.Empty, then block on the mutex.  each thread consumes a
        # connection as they go in.  when the limit is reached, the remaining
        # threads go in, and get TimeoutError; even though they never got to
        # wait for the timeout on queue.get().  the fix involves checking the
        # timeout again within the mutex, and if so, unlocking and throwing
        # them back to the start of do_get()
        dbapi = MockDBAPI()
        p = pool.QueuePool(
            creator=lambda: dbapi.connect(delay=.05),
            pool_size=2,
            max_overflow=1, use_threadlocal=False, timeout=3)
        timeouts = []

        def checkout():
            # each thread performs one checkout; either it obtains a
            # connection and holds it past the others' timeout window,
            # or it records how long the failed wait actually took
            for x in range(1):
                now = time.time()
                try:
                    c1 = p.connect()
                except tsa.exc.TimeoutError:
                    timeouts.append(time.time() - now)
                    continue
                time.sleep(4)
                c1.close()

        threads = []
        for i in range(10):
            th = threading.Thread(target=checkout)
            th.start()
            threads.append(th)
        for th in threads:
            th.join(join_timeout)

        # with only 3 connections available and 10 threads, some
        # checkouts must time out
        assert len(timeouts) > 0
        for t in timeouts:
            assert t >= 3, "Not all timeouts were >= 3 seconds %r" % timeouts
            # normally, the timeout should under 4 seconds,
            # but on a loaded down buildbot it can go up.
            assert t < 14, "Not all timeouts were < 14 seconds %r" % timeouts
1136
    def _test_overflow(self, thread_count, max_overflow):
        """Hammer a size-3 pool from ``thread_count`` threads and assert
        the observed overflow never exceeds ``max_overflow``."""
        gc_collect()

        dbapi = MockDBAPI()
        mutex = threading.Lock()

        def creator():
            # slow down connection creation to encourage contention
            time.sleep(.05)
            with mutex:
                return dbapi.connect()

        p = pool.QueuePool(creator=creator,
                           pool_size=3, timeout=2,
                           max_overflow=max_overflow)
        peaks = []

        def whammy():
            # repeatedly check out, sample the overflow level, check in
            for i in range(10):
                try:
                    con = p.connect()
                    time.sleep(.005)
                    peaks.append(p.overflow())
                    con.close()
                    del con
                except tsa.exc.TimeoutError:
                    pass
        threads = []
        for i in range(thread_count):
            th = threading.Thread(target=whammy)
            th.start()
            threads.append(th)
        for th in threads:
            th.join(join_timeout)

        self.assert_(max(peaks) <= max_overflow)

        lazy_gc()
        assert not pool._refs
1175
    def test_overflow_reset_on_failed_connect(self):
        """The overflow counter must be rolled back when the underlying
        connect attempt raises."""
        dbapi = Mock()

        def failing_dbapi():
            time.sleep(2)
            raise Exception("connection failed")

        creator = dbapi.connect

        def create():
            # indirection so the creator can be swapped mid-test
            return creator()

        p = pool.QueuePool(creator=create, pool_size=2, max_overflow=3)
        c1 = self._with_teardown(p.connect())  # noqa
        c2 = self._with_teardown(p.connect())  # noqa
        c3 = self._with_teardown(p.connect())  # noqa
        eq_(p._overflow, 1)
        creator = failing_dbapi
        assert_raises(Exception, p.connect)
        # the failed connect must not leave the counter incremented
        eq_(p._overflow, 1)
1196
    @testing.requires.threading_with_mock
    @testing.requires.timing_intensive
    def test_hanging_connect_within_overflow(self):
        """test that a single connect() call which is hanging
        does not block other connections from proceeding."""

        dbapi = Mock()
        mutex = threading.Lock()

        def hanging_dbapi():
            # simulates a connect that stalls for two seconds
            time.sleep(2)
            with mutex:
                return dbapi.connect()

        def fast_dbapi():
            with mutex:
                return dbapi.connect()

        # thread-local holder so each worker thread can pick its own
        # connector function
        creator = threading.local()

        def create():
            return creator.mock_connector()

        def run_test(name, pool, should_hang):
            if should_hang:
                creator.mock_connector = hanging_dbapi
            else:
                creator.mock_connector = fast_dbapi

            conn = pool.connect()
            conn.operation(name)
            time.sleep(1)
            conn.close()

        p = pool.QueuePool(creator=create, pool_size=2, max_overflow=3)

        threads = [
            threading.Thread(
                target=run_test, args=("success_one", p, False)),
            threading.Thread(
                target=run_test, args=("success_two", p, False)),
            threading.Thread(
                target=run_test, args=("overflow_one", p, True)),
            threading.Thread(
                target=run_test, args=("overflow_two", p, False)),
            threading.Thread(
                target=run_test, args=("overflow_three", p, False))
        ]
        for t in threads:
            t.start()
            time.sleep(.2)

        for t in threads:
            t.join(timeout=join_timeout)
        # the hanging connect ("overflow_one") finishes last; the two
        # checkouts started after it must not have been blocked by it
        eq_(
            dbapi.connect().operation.mock_calls,
            [call("success_one"), call("success_two"),
                call("overflow_two"), call("overflow_three"),
                call("overflow_one")]
        )
1257
    @testing.requires.threading_with_mock
    @testing.requires.timing_intensive
    def test_waiters_handled(self):
        """test that threads waiting for connections are
        handled when the pool is replaced.

        """
        mutex = threading.Lock()
        dbapi = MockDBAPI()

        def creator():
            mutex.acquire()
            try:
                return dbapi.connect()
            finally:
                mutex.release()

        success = []
        # run the scenario under every timeout/overflow combination
        for timeout in (None, 30):
            for max_overflow in (0, -1, 3):
                p = pool.QueuePool(creator=creator,
                                   pool_size=2, timeout=timeout,
                                   max_overflow=max_overflow)

                def waiter(p, timeout, max_overflow):
                    success_key = (timeout, max_overflow)
                    conn = p.connect()
                    success.append(success_key)
                    time.sleep(.1)
                    conn.close()

                c1 = p.connect()  # noqa
                c2 = p.connect()

                threads = []
                for i in range(2):
                    t = threading.Thread(
                        target=waiter,
                        args=(p, timeout, max_overflow))
                    t.daemon = True
                    t.start()
                    threads.append(t)

                # this sleep makes sure that the
                # two waiter threads hit upon wait()
                # inside the queue, before we invalidate the other
                # two conns
                time.sleep(.2)
                p._invalidate(c2)

                for t in threads:
                    t.join(join_timeout)

        # 2 waiters x 6 configurations = 12 successful checkouts
        eq_(len(success), 12, "successes: %s" % success)
1312
    def test_connrec_invalidated_within_checkout_no_race(self):
        """Test that a concurrent ConnectionRecord.invalidate() which
        occurs after the ConnectionFairy has called
        _ConnectionRecord.checkout()
        but before the ConnectionFairy tests "fairy.connection is None"
        will not result in an InvalidRequestError.

        This use case assumes that a listener on the checkout() event
        will be raising DisconnectionError so that a reconnect attempt
        may occur.

        """
        dbapi = MockDBAPI()

        def creator():
            return dbapi.connect()

        p = pool.QueuePool(creator=creator, pool_size=1, max_overflow=0)

        conn = p.connect()
        conn.close()

        _existing_checkout = pool._ConnectionRecord.checkout

        # wrap checkout() so the record is invalidated precisely in the
        # window between checkout() and the fairy's None-check
        @classmethod
        def _decorate_existing_checkout(cls, *arg, **kw):
            fairy = _existing_checkout(*arg, **kw)
            connrec = fairy._connection_record
            connrec.invalidate()
            return fairy

        with patch(
                "sqlalchemy.pool._ConnectionRecord.checkout",
                _decorate_existing_checkout):
            conn = p.connect()
            is_(conn._connection_record.connection, None)
        conn.close()
1350
    @testing.requires.threading_with_mock
    @testing.requires.timing_intensive
    def test_notify_waiters(self):
        dbapi = MockDBAPI()

        # canary records 1 for each connection created and 2 for each
        # waiter that successfully checks out
        canary = []

        def creator():
            canary.append(1)
            return dbapi.connect()
        p1 = pool.QueuePool(
            creator=creator,
            pool_size=1, timeout=None,
            max_overflow=0)

        def waiter(p):
            conn = p.connect()
            canary.append(2)
            time.sleep(.5)
            conn.close()

        c1 = p1.connect()

        threads = []
        for i in range(5):
            t = threading.Thread(target=waiter, args=(p1, ))
            t.start()
            threads.append(t)
        time.sleep(.5)
        # all five waiters are blocked; only the initial connect happened
        eq_(canary, [1])

        # this also calls invalidate()
        # on c1
        p1._invalidate(c1)

        for t in threads:
            t.join(join_timeout)

        # one reconnect after invalidation, then all five waiters proceed
        eq_(canary, [1, 1, 2, 2, 2, 2, 2])
1390
    def test_dispose_closes_pooled(self):
        dbapi = MockDBAPI()

        p = pool.QueuePool(creator=dbapi.connect,
                           pool_size=2, timeout=None,
                           max_overflow=0)
        c1 = p.connect()
        c2 = p.connect()
        c1_con = c1.connection
        c2_con = c2.connection

        c1.close()

        eq_(c1_con.close.call_count, 0)
        eq_(c2_con.close.call_count, 0)

        # dispose() closes only the connections checked into the pool
        p.dispose()

        eq_(c1_con.close.call_count, 1)
        eq_(c2_con.close.call_count, 0)

        # currently, if a ConnectionFairy is closed
        # after the pool has been disposed, there's no
        # flag that states it should be invalidated
        # immediately - it just gets returned to the
        # pool normally...
        c2.close()
        eq_(c1_con.close.call_count, 1)
        eq_(c2_con.close.call_count, 0)

        # ...and that's the one we'll get back next.
        c3 = p.connect()
        assert c3.connection is c2_con
1424
    @testing.requires.threading_with_mock
    @testing.requires.timing_intensive
    def test_no_overflow(self):
        # 40 threads against a pool that allows no overflow
        self._test_overflow(40, 0)

    @testing.requires.threading_with_mock
    @testing.requires.timing_intensive
    def test_max_overflow(self):
        # 40 threads against a pool allowing 5 overflow connections
        self._test_overflow(40, 5)
1434
    def test_mixed_close(self):
        # with use_threadlocal, two connect() calls in one thread return
        # the same fairy; one close() plus dropping the second reference
        # is required before the connection is fully checked in
        pool._refs.clear()
        p = self._queuepool_fixture(pool_size=3, max_overflow=-1,
                                    use_threadlocal=True)
        c1 = p.connect()
        c2 = p.connect()
        assert c1 is c2
        c1.close()
        c2 = None
        assert p.checkedout() == 1
        c1 = None
        lazy_gc()
        assert p.checkedout() == 0
        lazy_gc()
        assert not pool._refs
1450
    def test_overflow_no_gc_tlocal(self):
        # NOTE(review): the flag passed here is currently ignored by
        # _test_overflow_no_gc, so both variants run the same
        # non-threadlocal scenario — confirm intent
        self._test_overflow_no_gc(True)

    def test_overflow_no_gc(self):
        self._test_overflow_no_gc(False)
1456
1457    def _test_overflow_no_gc(self, threadlocal):
1458        p = self._queuepool_fixture(
1459            pool_size=2,
1460            max_overflow=2)
1461
1462        # disable weakref collection of the
1463        # underlying connections
1464        strong_refs = set()
1465
1466        def _conn():
1467            c = p.connect()
1468            strong_refs.add(c.connection)
1469            return c
1470
1471        for j in range(5):
1472            # open 4 conns at a time.  each time this
1473            # will yield two pooled connections + two
1474            # overflow connections.
1475            conns = [_conn() for i in range(4)]
1476            for c in conns:
1477                c.close()
1478
1479        # doing that for a total of 5 times yields
1480        # ten overflow connections closed plus the
1481        # two pooled connections unclosed.
1482
1483        eq_(
1484            set([c.close.call_count for c in strong_refs]),
1485            set([1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0])
1486        )
1487
    @testing.requires.predictable_gc
    def test_weakref_kaboom(self):
        # dropping all references to threadlocal fairies, even after a
        # mixed close()/discard, must check everything back in once
        # garbage collection runs
        p = self._queuepool_fixture(
            pool_size=3,
            max_overflow=-1, use_threadlocal=True)
        c1 = p.connect()
        c2 = p.connect()
        c1.close()
        c2 = None
        del c1
        del c2
        gc_collect()
        assert p.checkedout() == 0
        c3 = p.connect()
        assert c3 is not None
1503
    def test_trick_the_counter(self):
        """this is a "flaw" in the connection pool; since threadlocal
        uses a single ConnectionFairy per thread with an open/close
        counter, you can fool the counter into giving you a
        ConnectionFairy with an ambiguous counter.  i.e. its not true
        reference counting."""

        p = self._queuepool_fixture(
            pool_size=3,
            max_overflow=-1, use_threadlocal=True)
        c1 = p.connect()
        c2 = p.connect()
        assert c1 is c2
        c1.close()
        c2 = p.connect()
        c2.close()
        # the counter is now off by one; a second close() is needed
        # before the connection is really checked in
        self.assert_(p.checkedout() != 0)
        c2.close()
        self.assert_(p.checkedout() == 0)
1523
    def test_recycle(self):
        # drive pool.time.time directly so the recycle interval can be
        # tested without real waiting
        with patch("sqlalchemy.pool.time.time") as mock:
            mock.return_value = 10000

            p = self._queuepool_fixture(
                pool_size=1,
                max_overflow=0,
                recycle=30)
            c1 = p.connect()
            c_ref = weakref.ref(c1.connection)
            c1.close()
            mock.return_value = 10001
            c2 = p.connect()

            # inside the 30-second recycle window: same connection
            is_(c2.connection, c_ref())
            c2.close()

            mock.return_value = 10035
            c3 = p.connect()
            # past the recycle window: a new connection is created
            is_not_(c3.connection, c_ref())
1544
    @testing.requires.timing_intensive
    def test_recycle_on_invalidate(self):
        p = self._queuepool_fixture(
            pool_size=1,
            max_overflow=0)
        c1 = p.connect()
        c_ref = weakref.ref(c1.connection)
        c1.close()
        c2 = p.connect()
        is_(c2.connection, c_ref())

        c2_rec = c2._connection_record
        # hard invalidation clears the record's connection immediately
        p._invalidate(c2)
        assert c2_rec.connection is None
        c2.close()
        time.sleep(.5)
        c3 = p.connect()

        # the next checkout receives a freshly created connection
        is_not_(c3.connection, c_ref())
1564
    @testing.requires.timing_intensive
    def test_recycle_on_soft_invalidate(self):
        p = self._queuepool_fixture(
            pool_size=1,
            max_overflow=0)
        c1 = p.connect()
        c_ref = weakref.ref(c1.connection)
        c1.close()
        c2 = p.connect()
        is_(c2.connection, c_ref())

        c2_rec = c2._connection_record
        # soft invalidation leaves the current connection in place...
        c2.invalidate(soft=True)
        is_(c2_rec.connection, c2.connection)

        c2.close()
        time.sleep(.5)
        c3 = p.connect()
        # ...but the same record opens a new connection on next checkout
        is_not_(c3.connection, c_ref())
        is_(c3._connection_record, c2_rec)
        is_(c2_rec.connection, c3.connection)
1586
    def _no_wr_finalize(self):
        """Return a patch context that fails the test if
        pool._finalize_fairy is ever invoked as a weakref callback
        (fairy=None) rather than via an explicit close()."""
        finalize_fairy = pool._finalize_fairy

        def assert_no_wr_callback(
            connection, connection_record,
                pool, ref, echo, fairy=None):
            if fairy is None:
                raise AssertionError(
                    "finalize fairy was called as a weakref callback")
            return finalize_fairy(
                connection, connection_record, pool, ref, echo, fairy)
        return patch.object(
            pool, '_finalize_fairy', assert_no_wr_callback)
1600
    def _assert_cleanup_on_pooled_reconnect(self, dbapi, p):
        # p is QueuePool with size=1, max_overflow=2,
        # and one connection in the pool that will need to
        # reconnect when next used (either due to recycle or invalidate)

        with self._no_wr_finalize():
            eq_(p.checkedout(), 0)
            eq_(p._overflow, 0)
            dbapi.shutdown(True)
            assert_raises(
                Exception,
                p.connect
            )
            # the failed checkout must not leak overflow or checkout
            # counts
            eq_(p._overflow, 0)
            eq_(p.checkedout(), 0)  # and not 1

            dbapi.shutdown(False)

            c1 = self._with_teardown(p.connect())  # noqa
            assert p._pool.empty()  # poolsize is one, so we're empty OK
            c2 = self._with_teardown(p.connect())  # noqa
            eq_(p._overflow, 1)  # and not 2

            # this hangs if p._overflow is 2
            c3 = self._with_teardown(p.connect())

            c3.close()
1628
    def test_error_on_pooled_reconnect_cleanup_invalidate(self):
        # pooled connection requires reconnect due to invalidate()
        dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=2)
        c1 = p.connect()
        c1.invalidate()
        c1.close()
        self._assert_cleanup_on_pooled_reconnect(dbapi, p)

    @testing.requires.timing_intensive
    def test_error_on_pooled_reconnect_cleanup_recycle(self):
        # pooled connection requires reconnect due to the 1-second
        # recycle interval elapsing
        dbapi, p = self._queuepool_dbapi_fixture(
            pool_size=1,
            max_overflow=2, recycle=1)
        c1 = p.connect()
        c1.close()
        time.sleep(1.5)
        self._assert_cleanup_on_pooled_reconnect(dbapi, p)
1645
    def test_connect_handler_not_called_for_recycled(self):
        """test [ticket:3497]"""

        dbapi, p = self._queuepool_dbapi_fixture(
            pool_size=2, max_overflow=2)

        canary = Mock()

        c1 = p.connect()
        c2 = p.connect()

        c1.close()
        c2.close()

        dbapi.shutdown(True)

        # invalidate one connection so the pool enters "recently
        # invalidated" mode for the remaining pooled connections
        bad = p.connect()
        p._invalidate(bad)
        bad.close()
        assert p._invalidate_time

        event.listen(p, "connect", canary.connect)
        event.listen(p, "checkout", canary.checkout)

        assert_raises(
            Exception,
            p.connect
        )

        # prune dead records out of the underlying queue
        p._pool.queue = collections.deque(
            [
                c for c in p._pool.queue
                if c.connection is not None
            ]
        )

        dbapi.shutdown(False)
        c = p.connect()
        c.close()

        # exactly one fresh connect and one checkout were seen; the
        # recycled-connection path did not fire connect a second time
        eq_(
            canary.mock_calls,
            [
                call.connect(ANY, ANY),
                call.checkout(ANY, ANY, ANY)
            ]
        )
1693
    def test_connect_checkout_handler_always_gets_info(self):
        """test [ticket:3497]"""

        dbapi, p = self._queuepool_dbapi_fixture(
            pool_size=2, max_overflow=2)

        c1 = p.connect()
        c2 = p.connect()

        c1.close()
        c2.close()

        dbapi.shutdown(True)

        # invalidate one connection so the pool enters "recently
        # invalidated" mode for the remaining pooled connections
        bad = p.connect()
        p._invalidate(bad)
        bad.close()
        assert p._invalidate_time

        @event.listens_for(p, "connect")
        def connect(conn, conn_rec):
            conn_rec.info['x'] = True

        @event.listens_for(p, "checkout")
        def checkout(conn, conn_rec, conn_f):
            # the connect handler must always have run first, even on
            # the recycled-connection path
            assert 'x' in conn_rec.info

        assert_raises(
            Exception,
            p.connect
        )

        # prune dead records out of the underlying queue
        p._pool.queue = collections.deque(
            [
                c for c in p._pool.queue
                if c.connection is not None
            ]
        )

        dbapi.shutdown(False)
        c = p.connect()
        c.close()
1736
    def test_error_on_pooled_reconnect_cleanup_wcheckout_event(self):
        dbapi, p = self._queuepool_dbapi_fixture(
            pool_size=1,
            max_overflow=2)

        c1 = p.connect()
        c1.close()

        @event.listens_for(p, "checkout")
        def handle_checkout_event(dbapi_con, con_record, con_proxy):
            # raise DisconnectionError while the database is "down" so
            # the pool attempts a reconnect within the checkout
            if dbapi.is_shutdown:
                raise tsa.exc.DisconnectionError()

        self._assert_cleanup_on_pooled_reconnect(dbapi, p)
1751
1752    @testing.requires.predictable_gc
1753    def test_userspace_disconnectionerror_weakref_finalizer(self):
1754        dbapi, pool = self._queuepool_dbapi_fixture(
1755            pool_size=1,
1756            max_overflow=2)
1757
1758        @event.listens_for(pool, "checkout")
1759        def handle_checkout_event(dbapi_con, con_record, con_proxy):
1760            if getattr(dbapi_con, 'boom') == 'yes':
1761                raise tsa.exc.DisconnectionError()
1762
1763        conn = pool.connect()
1764        old_dbapi_conn = conn.connection
1765        conn.close()
1766
1767        eq_(old_dbapi_conn.mock_calls, [call.rollback()])
1768
1769        old_dbapi_conn.boom = 'yes'
1770
1771        conn = pool.connect()
1772        dbapi_conn = conn.connection
1773        del conn
1774        gc_collect()
1775
1776        # new connection was reset on return appropriately
1777        eq_(dbapi_conn.mock_calls, [call.rollback()])
1778
1779        # old connection was just closed - did not get an
1780        # erroneous reset on return
1781        eq_(
1782            old_dbapi_conn.mock_calls,
1783            [call.rollback(), call.close()]
1784        )
1785
    @testing.requires.timing_intensive
    def test_recycle_pool_no_race(self):
        def slow_close():
            # a slow-returning close, to widen the race window
            slow_closing_connection._slow_close()
            time.sleep(.5)

        slow_closing_connection = Mock()
        slow_closing_connection.connect.return_value.close = slow_close

        class Error(Exception):
            pass

        dialect = Mock()
        dialect.is_disconnect = lambda *arg, **kw: True
        dialect.dbapi.Error = Error

        pools = []

        class TrackQueuePool(pool.QueuePool):
            # records every pool created so we can assert that the
            # concurrent invalidations don't recreate the pool repeatedly
            def __init__(self, *arg, **kw):
                pools.append(self)
                super(TrackQueuePool, self).__init__(*arg, **kw)

        def creator():
            return slow_closing_connection.connect()
        p1 = TrackQueuePool(creator=creator, pool_size=20)

        from sqlalchemy import create_engine
        eng = create_engine(testing.db.url, pool=p1, _initialize=False)
        eng.dialect = dialect

        # 15 total connections
        conns = [eng.connect() for i in range(15)]

        # return 7 of them to the pool (conns[3:10])
        for conn in conns[3:10]:
            conn.close()

        def attempt(conn):
            time.sleep(random.random())
            try:
                conn._handle_dbapi_exception(
                    Error(), "statement", {},
                    Mock(), Mock())
            except tsa.exc.DBAPIError:
                pass

        # run an error + invalidate operation on each of the 15
        # connections (8 still checked out, 7 already returned)
        threads = []
        for conn in conns:
            t = threading.Thread(target=attempt, args=(conn, ))
            t.start()
            threads.append(t)

        for t in threads:
            t.join()

        # return all 15 connections to the pool
        for conn in conns:
            conn.close()

        # re-open 15 total connections
        conns = [eng.connect() for i in range(15)]

        # 15 connections have been fully closed due to invalidate
        assert slow_closing_connection._slow_close.call_count == 15

        # 15 initial connections + 15 reconnections
        assert slow_closing_connection.connect.call_count == 30
        assert len(pools) <= 2, len(pools)
1857
1858    def test_invalidate(self):
1859        p = self._queuepool_fixture(pool_size=1, max_overflow=0)
1860        c1 = p.connect()
1861        c_id = c1.connection.id
1862        c1.close()
1863        c1 = None
1864        c1 = p.connect()
1865        assert c1.connection.id == c_id
1866        c1.invalidate()
1867        c1 = None
1868        c1 = p.connect()
1869        assert c1.connection.id != c_id
1870
1871    def test_recreate(self):
1872        p = self._queuepool_fixture(reset_on_return=None, pool_size=1,
1873                                    max_overflow=0)
1874        p2 = p.recreate()
1875        assert p2.size() == 1
1876        assert p2._reset_on_return is pool.reset_none
1877        assert p2._use_threadlocal is False
1878        assert p2._max_overflow == 0
1879
1880    def test_reconnect(self):
1881        """tests reconnect operations at the pool level.  SA's
1882        engine/dialect includes another layer of reconnect support for
1883        'database was lost' errors."""
1884
1885        dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=0)
1886        c1 = p.connect()
1887        c_id = c1.connection.id
1888        c1.close()
1889        c1 = None
1890        c1 = p.connect()
1891        assert c1.connection.id == c_id
1892        dbapi.raise_error = True
1893        c1.invalidate()
1894        c1 = None
1895        c1 = p.connect()
1896        assert c1.connection.id != c_id
1897
1898    def test_detach(self):
1899        dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=0)
1900        c1 = p.connect()
1901        c1.detach()
1902        c2 = p.connect()  # noqa
1903        eq_(dbapi.connect.mock_calls, [call("foo.db"), call("foo.db")])
1904
1905        c1_con = c1.connection
1906        assert c1_con is not None
1907        eq_(c1_con.close.call_count, 0)
1908        c1.close()
1909        eq_(c1_con.close.call_count, 1)
1910
1911    def test_detach_via_invalidate(self):
1912        dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=0)
1913
1914        c1 = p.connect()
1915        c1_con = c1.connection
1916        c1.invalidate()
1917        assert c1.connection is None
1918        eq_(c1_con.close.call_count, 1)
1919
1920        c2 = p.connect()
1921        assert c2.connection is not c1_con
1922        c2_con = c2.connection
1923
1924        c2.close()
1925        eq_(c2_con.close.call_count, 0)
1926
1927    def test_threadfairy(self):
1928        p = self._queuepool_fixture(pool_size=3, max_overflow=-1,
1929                                    use_threadlocal=True)
1930        c1 = p.connect()
1931        c1.close()
1932        c2 = p.connect()
1933        assert c2.connection is not None
1934
1935
class ResetOnReturnTest(PoolTestBase):
    """Tests for QueuePool reset-on-return behaviors: 'rollback',
    'commit', None, and delegation to a _reset_agent."""

    def _fixture(self, **kw):
        # dbapi.connect() is a Mock and returns the same child Mock on
        # every call, so call counts accumulate across checkouts
        dbapi = Mock()
        return dbapi, pool.QueuePool(
            creator=lambda: dbapi.connect('foo.db'),
            **kw)

    def test_plain_rollback(self):
        dbapi, p = self._fixture(reset_on_return='rollback')

        c1 = p.connect()
        c1.close()
        assert dbapi.connect().rollback.called
        assert not dbapi.connect().commit.called

    def test_plain_commit(self):
        dbapi, p = self._fixture(reset_on_return='commit')

        c1 = p.connect()
        c1.close()
        assert not dbapi.connect().rollback.called
        assert dbapi.connect().commit.called

    def test_plain_none(self):
        dbapi, p = self._fixture(reset_on_return=None)

        c1 = p.connect()
        c1.close()
        assert not dbapi.connect().rollback.called
        assert not dbapi.connect().commit.called

    def test_agent_rollback(self):
        dbapi, p = self._fixture(reset_on_return='rollback')

        class Agent(object):
            # stands in for a transaction-like object that intercepts
            # the reset-on-return operation
            def __init__(self, conn):
                self.conn = conn

            def rollback(self):
                self.conn.special_rollback()

            def commit(self):
                self.conn.special_commit()

        c1 = p.connect()
        c1._reset_agent = Agent(c1)
        c1.close()

        # with an agent present, the agent's rollback is used...
        assert dbapi.connect().special_rollback.called
        assert not dbapi.connect().special_commit.called

        # ...and the plain DBAPI rollback is bypassed
        assert not dbapi.connect().rollback.called
        assert not dbapi.connect().commit.called

        c1 = p.connect()
        c1.close()
        # without an agent, the plain rollback path resumes
        eq_(dbapi.connect().special_rollback.call_count, 1)
        eq_(dbapi.connect().special_commit.call_count, 0)

        assert dbapi.connect().rollback.called
        assert not dbapi.connect().commit.called

    def test_agent_commit(self):
        dbapi, p = self._fixture(reset_on_return='commit')

        class Agent(object):
            # same stand-in as above, for the 'commit' configuration
            def __init__(self, conn):
                self.conn = conn

            def rollback(self):
                self.conn.special_rollback()

            def commit(self):
                self.conn.special_commit()

        c1 = p.connect()
        c1._reset_agent = Agent(c1)
        c1.close()
        # the agent's commit is used in place of the DBAPI commit
        assert not dbapi.connect().special_rollback.called
        assert dbapi.connect().special_commit.called

        assert not dbapi.connect().rollback.called
        assert not dbapi.connect().commit.called

        c1 = p.connect()
        c1.close()

        # without an agent, the plain commit path resumes
        eq_(dbapi.connect().special_rollback.call_count, 0)
        eq_(dbapi.connect().special_commit.call_count, 1)
        assert not dbapi.connect().rollback.called
        assert dbapi.connect().commit.called
2027
2028
class SingletonThreadPoolTest(PoolTestBase):

    @testing.requires.threading_with_mock
    def test_cleanup(self):
        self._test_cleanup(False)

    @testing.requires.threading_with_mock
    def test_cleanup_no_gc(self):
        self._test_cleanup(True)

    def _test_cleanup(self, strong_refs):
        """test that the pool's connections are OK after cleanup() has
        been called."""

        dbapi = MockDBAPI()

        creator_mutex = threading.Lock()

        def creator():
            # the mock iterator isn't threadsafe...
            with creator_mutex:
                return dbapi.connect()

        p = pool.SingletonThreadPool(creator=creator, pool_size=3)

        if strong_refs:
            # keep hard references to every raw connection handed out so
            # we can inspect close() counts afterwards
            held = set()

            def _conn():
                fairy = p.connect()
                held.add(fairy.connection)
                return fairy
        else:
            def _conn():
                return p.connect()

        def checkout():
            for _ in range(10):
                fairy = _conn()
                assert fairy
                fairy.cursor()
                fairy.close()
                time.sleep(.1)

        workers = [threading.Thread(target=checkout) for _ in range(10)]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join(join_timeout)
        # the pool trims itself back down to pool_size
        assert len(p._all_conns) == 3

        if strong_refs:
            still_opened = len(
                [conn for conn in held if not conn.close.call_count])
            eq_(still_opened, 3)
2084
2085
class AssertionPoolTest(PoolTestBase):
    """AssertionPool allows at most one outstanding checkout."""

    def test_connect_error(self):
        dbapi = MockDBAPI()
        p = pool.AssertionPool(creator=lambda: dbapi.connect('foo.db'))
        held = p.connect()  # noqa
        # second checkout while one is outstanding must fail
        assert_raises(AssertionError, p.connect)

    def test_connect_multiple(self):
        dbapi = MockDBAPI()
        p = pool.AssertionPool(creator=lambda: dbapi.connect('foo.db'))
        # sequential checkout/checkin cycles are fine
        for _ in range(2):
            conn = p.connect()
            conn.close()

        held = p.connect()  # noqa
        assert_raises(AssertionError, p.connect)
2103
2104
class NullPoolTest(PoolTestBase):
    def test_reconnect(self):
        """NullPool opens a fresh DBAPI connection per checkout, including
        after an invalidate."""
        dbapi = MockDBAPI()
        p = pool.NullPool(creator=lambda: dbapi.connect('foo.db'))

        conn = p.connect()
        conn.close()
        conn = None

        conn = p.connect()
        conn.invalidate()
        conn = None

        conn = p.connect()
        dbapi.connect.assert_has_calls(
            [call('foo.db'), call('foo.db')],
            any_order=True)
2124
2125
class StaticPoolTest(PoolTestBase):
    def test_recreate(self):
        """recreate() hands the identical creator to the new pool."""
        dbapi = MockDBAPI()

        def creator():
            return dbapi.connect('foo.db')

        original = pool.StaticPool(creator)
        clone = original.recreate()
        assert original._creator is clone._creator
2135
2136
class CreatorCompatibilityTest(PoolTestBase):
    """Tests for the pool's ``_creator`` callable: it should be usable
    both with and without the connection-record argument, and the pool
    should re-detect the arity if ``_creator`` is monkeypatched.

    Bug fixed: the original tests assigned ``conn = creator(...)`` inside
    a ``try:`` whose ``finally:`` referenced ``conn``.  If the creator
    raised, ``conn`` was unbound and the ``finally`` raised
    UnboundLocalError, masking the real failure.  The assignment now
    happens before the cleanup call.
    """

    def test_creator_callable_outside_noarg(self):
        e = testing_engine()

        creator = e.pool._creator
        # if creator() raises there is no connection to close, so no
        # try/finally is needed around the call itself
        conn = creator()
        conn.close()

    def test_creator_callable_outside_witharg(self):
        e = testing_engine()

        creator = e.pool._creator
        # the creator also accepts a connection-record argument
        conn = creator(Mock())
        conn.close()

    def test_creator_patching_arg_to_noarg(self):
        e = testing_engine()
        creator = e.pool._creator
        # the creator is the two-arg form
        conn = creator(Mock())
        conn.close()

        def mock_create():
            return creator()

        conn = e.connect()
        conn.invalidate()
        conn.close()

        # test that the 'should_wrap_creator' status
        # will dynamically switch if the _creator is monkeypatched.

        # patch it with a zero-arg form
        with patch.object(e.pool, "_creator", mock_create):
            conn = e.connect()
            conn.invalidate()
            conn.close()

        conn = e.connect()
        conn.close()
2183