1import collections
2import random
3import threading
4import time
5import weakref
6
7import sqlalchemy as tsa
8from sqlalchemy import event
9from sqlalchemy import pool
10from sqlalchemy import select
11from sqlalchemy import testing
12from sqlalchemy.testing import assert_raises
13from sqlalchemy.testing import assert_raises_message
14from sqlalchemy.testing import eq_
15from sqlalchemy.testing import fixtures
16from sqlalchemy.testing import is_
17from sqlalchemy.testing import is_not_
18from sqlalchemy.testing.engines import testing_engine
19from sqlalchemy.testing.mock import ANY
20from sqlalchemy.testing.mock import call
21from sqlalchemy.testing.mock import Mock
22from sqlalchemy.testing.mock import patch
23from sqlalchemy.testing.util import gc_collect
24from sqlalchemy.testing.util import lazy_gc
25
26join_timeout = 10
27
28
def MockDBAPI():  # noqa
    """Build a mock DBAPI-like module object.

    ``connect()`` returns Mock connections that track a ``closed`` flag,
    and ``shutdown(value)`` swaps ``connect`` between a working factory
    and one that raises, recording the state in ``is_shutdown``.
    """

    def make_cursor():
        return Mock()

    def make_connection(*arg, **kw):
        connection = Mock(
            cursor=Mock(side_effect=make_cursor),
            closed=False,
        )

        # mock seems like it might have an issue logging
        # call_count correctly under threading, not sure.
        # adding a side_effect for close seems to help.
        def mark_closed():
            connection.closed = True

        connection.close = Mock(side_effect=mark_closed)
        return connection

    def set_shutdown(value):
        db.is_shutdown = value
        if value:
            db.connect = Mock(side_effect=Exception("connect failed"))
        else:
            db.connect = Mock(side_effect=make_connection)

    db = Mock(
        connect=Mock(side_effect=make_connection),
        shutdown=set_shutdown,
        is_shutdown=False,
    )
    return db
58
59
class PoolTestBase(fixtures.TestBase):
    """Shared setup/teardown and pool fixtures for the pool test suite."""

    def setup(self):
        pool.clear_managers()
        # weak references to connections registered via _with_teardown()
        self._teardown_conns = []

    def teardown(self):
        # close any registered connections that are still alive
        for ref in self._teardown_conns:
            conn = ref()
            if conn:
                conn.close()

    @classmethod
    def teardown_class(cls):
        pool.clear_managers()

    def _with_teardown(self, connection):
        """Register *connection* to be closed at teardown; return it."""
        self._teardown_conns.append(weakref.ref(connection))
        return connection

    def _queuepool_fixture(self, **kw):
        """Return a QueuePool backed by a fresh MockDBAPI."""
        # local renamed from ``pool`` so it doesn't shadow the
        # module-level ``sqlalchemy.pool`` import
        dbapi, queue_pool = self._queuepool_dbapi_fixture(**kw)
        return queue_pool

    def _queuepool_dbapi_fixture(self, **kw):
        """Return a ``(mock_dbapi, QueuePool)`` pair."""
        dbapi = MockDBAPI()
        return (
            dbapi,
            pool.QueuePool(creator=lambda: dbapi.connect("foo.db"), **kw),
        )
89
90
class PoolTest(PoolTestBase):
    """Core pool behavior: ``pool.manage()`` connection caching,
    threadlocal checkout semantics, the per-checkout ``.info`` and
    per-record ``.record_info`` dictionaries, and the
    ``_ConnectionRecord`` lifecycle.
    """

    def test_manager(self):
        # threadlocal manager caches one connection per distinct
        # connect-argument set within a thread
        manager = pool.manage(MockDBAPI(), use_threadlocal=True)

        c1 = manager.connect("foo.db")
        c2 = manager.connect("foo.db")
        c3 = manager.connect("bar.db")
        c4 = manager.connect("foo.db", bar="bat")
        c5 = manager.connect("foo.db", bar="hoho")
        c6 = manager.connect("foo.db", bar="bat")

        assert c1.cursor() is not None
        assert c1 is c2
        assert c1 is not c3
        assert c4 is c6
        assert c4 is not c5

    def test_manager_with_key(self):
        # sa_pool_key overrides argument-based caching: same key shares
        # a connection even with different connect arguments
        dbapi = MockDBAPI()
        manager = pool.manage(dbapi, use_threadlocal=True)

        c1 = manager.connect("foo.db", sa_pool_key="a")
        c2 = manager.connect("foo.db", sa_pool_key="b")
        c3 = manager.connect("bar.db", sa_pool_key="a")

        assert c1.cursor() is not None
        assert c1 is not c2
        assert c1 is c3

        # only two real DBAPI connects occurred; "bar.db" reused key "a"
        eq_(dbapi.connect.mock_calls, [call("foo.db"), call("foo.db")])

    def test_bad_args(self):
        # a None argument should be accepted without raising
        manager = pool.manage(MockDBAPI())
        manager.connect(None)

    def test_non_thread_local_manager(self):
        # without threadlocal, each connect() is a distinct checkout
        manager = pool.manage(MockDBAPI(), use_threadlocal=False)

        connection = manager.connect("foo.db")
        connection2 = manager.connect("foo.db")

        self.assert_(connection.cursor() is not None)
        self.assert_(connection is not connection2)

    @testing.fails_on(
        "+pyodbc", "pyodbc cursor doesn't implement tuple __eq__"
    )
    @testing.fails_on("+pg8000", "returns [1], not (1,)")
    def test_cursor_iterable(self):
        # a raw DBAPI cursor should be iterable, yielding row tuples
        conn = testing.db.raw_connection()
        cursor = conn.cursor()
        cursor.execute(str(select([1], bind=testing.db)))
        expected = [(1,)]
        for row in cursor:
            eq_(row, expected.pop(0))

    def test_no_connect_on_recreate(self):
        # recreate() must not connect; a failing creator proves it
        def creator():
            raise Exception("no creates allowed")

        for cls in (
            pool.SingletonThreadPool,
            pool.StaticPool,
            pool.QueuePool,
            pool.NullPool,
            pool.AssertionPool,
        ):
            p = cls(creator=creator)
            p.dispose()
            p2 = p.recreate()
            assert p2.__class__ is cls

            # same check after a successful connect/close cycle, with
            # the creator then made to fail
            mock_dbapi = MockDBAPI()
            p = cls(creator=mock_dbapi.connect)
            conn = p.connect()
            conn.close()
            mock_dbapi.connect.side_effect = Exception("error!")
            p.dispose()
            p.recreate()

    def test_threadlocal_del(self):
        self._do_testthreadlocal(useclose=False)

    def test_threadlocal_close(self):
        self._do_testthreadlocal(useclose=True)

    def _do_testthreadlocal(self, useclose=False):
        """Exercise threadlocal checkout for QueuePool and
        SingletonThreadPool, releasing either via close() or via
        dereference + gc depending on *useclose*."""
        dbapi = MockDBAPI()
        for p in (
            pool.QueuePool(
                creator=dbapi.connect,
                pool_size=3,
                max_overflow=-1,
                use_threadlocal=True,
            ),
            pool.SingletonThreadPool(
                creator=dbapi.connect, use_threadlocal=True
            ),
        ):
            # repeated connect() in the same thread returns the same
            # proxy; unique_connection() bypasses the threadlocal cache
            c1 = p.connect()
            c2 = p.connect()
            self.assert_(c1 is c2)
            c3 = p.unique_connection()
            self.assert_(c3 is not c1)
            if useclose:
                c2.close()
            else:
                c2 = None
            c2 = p.connect()
            self.assert_(c1 is c2)
            self.assert_(c3 is not c1)
            if useclose:
                c2.close()
            else:
                c2 = None
                lazy_gc()
            if useclose:
                c1 = p.connect()
                c2 = p.connect()
                c3 = p.connect()
                c3.close()
                c2.close()
                self.assert_(c1.connection is not None)
                c1.close()
            c1 = c2 = c3 = None

            # extra tests with QueuePool to ensure connections get
            # __del__()ed when dereferenced

            if isinstance(p, pool.QueuePool):
                lazy_gc()
                self.assert_(p.checkedout() == 0)
                c1 = p.connect()
                c2 = p.connect()
                if useclose:
                    c2.close()
                    c1.close()
                else:
                    c2 = None
                    c1 = None
                    lazy_gc()
                self.assert_(p.checkedout() == 0)

    def test_info(self):
        # .info persists across checkouts on the same record, is wiped
        # on invalidate(), and follows the connection on detach()
        p = self._queuepool_fixture(pool_size=1, max_overflow=0)

        c = p.connect()
        self.assert_(not c.info)
        self.assert_(c.info is c._connection_record.info)

        c.info["foo"] = "bar"
        c.close()
        del c

        c = p.connect()
        self.assert_("foo" in c.info)

        c.invalidate()
        c = p.connect()
        self.assert_("foo" not in c.info)

        c.info["foo2"] = "bar2"
        c.detach()
        self.assert_("foo2" in c.info)

        c2 = p.connect()
        is_not_(c.connection, c2.connection)
        assert not c2.info
        assert "foo2" in c.info

    def test_rec_info(self):
        # .record_info persists across checkouts AND across
        # invalidate(); on detach() the record (and record_info) is
        # dissociated from the connection
        p = self._queuepool_fixture(pool_size=1, max_overflow=0)

        c = p.connect()
        self.assert_(not c.record_info)
        self.assert_(c.record_info is c._connection_record.record_info)

        c.record_info["foo"] = "bar"
        c.close()
        del c

        c = p.connect()
        self.assert_("foo" in c.record_info)

        c.invalidate()
        c = p.connect()
        self.assert_("foo" in c.record_info)

        c.record_info["foo2"] = "bar2"
        c.detach()
        is_(c.record_info, None)
        is_(c._connection_record, None)

        c2 = p.connect()

        assert c2.record_info
        assert "foo2" in c2.record_info

    def test_rec_unconnected(self):
        # test production of a _ConnectionRecord with an
        # initially unconnected state.

        dbapi = MockDBAPI()
        p1 = pool.Pool(creator=lambda: dbapi.connect("foo.db"))

        r1 = pool._ConnectionRecord(p1, connect=False)

        assert not r1.connection
        c1 = r1.get_connection()
        is_(c1, r1.connection)

    def test_rec_close_reopen(self):
        # test that _ConnectionRecord.close() allows
        # the record to be reusable
        dbapi = MockDBAPI()
        p1 = pool.Pool(creator=lambda: dbapi.connect("foo.db"))

        r1 = pool._ConnectionRecord(p1)

        c1 = r1.connection
        c2 = r1.get_connection()
        is_(c1, c2)

        r1.close()

        # close() closed the DBAPI connection and cleared the record
        assert not r1.connection
        eq_(c1.mock_calls, [call.close()])

        # a subsequent get_connection() produces a brand-new connection
        c2 = r1.get_connection()

        is_not_(c1, c2)
        is_(c2, r1.connection)

        eq_(c2.mock_calls, [])
326
327
class PoolDialectTest(PoolTestBase):
    """Verify each pool class routes rollback/close through its
    ``_dialect`` hooks in the expected sequence."""

    def _dialect(self):
        """Return a recording dialect plus the list it appends to."""
        recorded = []

        class PoolDialect(object):
            def do_rollback(self, dbapi_connection):
                recorded.append("R")
                dbapi_connection.rollback()

            def do_commit(self, dbapi_connection):
                recorded.append("C")
                dbapi_connection.commit()

            def do_close(self, dbapi_connection):
                recorded.append("CL")
                dbapi_connection.close()

        return PoolDialect(), recorded

    def _do_test(self, pool_cls, assertion):
        """Run a checkout/dispose/recreate cycle on *pool_cls* and
        compare the recorded dialect calls against *assertion*."""
        mock_dbapi = MockDBAPI()
        dialect, recorded = self._dialect()

        p = pool_cls(creator=mock_dbapi.connect)
        p._dialect = dialect

        checked_out = p.connect()
        checked_out.close()
        p.dispose()
        p.recreate()
        checked_out = p.connect()
        checked_out.close()

        eq_(recorded, assertion)

    def test_queue_pool(self):
        self._do_test(pool.QueuePool, ["R", "CL", "R"])

    def test_assertion_pool(self):
        self._do_test(pool.AssertionPool, ["R", "CL", "R"])

    def test_singleton_pool(self):
        self._do_test(pool.SingletonThreadPool, ["R", "CL", "R"])

    def test_null_pool(self):
        self._do_test(pool.NullPool, ["R", "CL", "R", "CL"])

    def test_static_pool(self):
        self._do_test(pool.StaticPool, ["R", "R"])
375
376
class PoolEventsTest(PoolTestBase):
    """Tests for the PoolEvents hooks (connect, first_connect, checkout,
    checkin, reset, invalidate, soft_invalidate, close, detach,
    close_detached), including behavior across ``recreate()`` and event
    target scoping (Pool class vs. instance vs. Engine)."""

    # -- fixtures: each returns (pool, canary) where the canary records
    # -- firings of a single event

    def _first_connect_event_fixture(self):
        p = self._queuepool_fixture()
        canary = []

        def first_connect(*arg, **kw):
            canary.append("first_connect")

        event.listen(p, "first_connect", first_connect)

        return p, canary

    def _connect_event_fixture(self):
        p = self._queuepool_fixture()
        canary = []

        def connect(*arg, **kw):
            canary.append("connect")

        event.listen(p, "connect", connect)

        return p, canary

    def _checkout_event_fixture(self):
        p = self._queuepool_fixture()
        canary = []

        def checkout(*arg, **kw):
            canary.append("checkout")

        event.listen(p, "checkout", checkout)

        return p, canary

    def _checkin_event_fixture(self):
        p = self._queuepool_fixture()
        canary = []

        def checkin(*arg, **kw):
            canary.append("checkin")

        event.listen(p, "checkin", checkin)

        return p, canary

    def _reset_event_fixture(self):
        p = self._queuepool_fixture()
        canary = []

        def reset(*arg, **kw):
            canary.append("reset")

        event.listen(p, "reset", reset)

        return p, canary

    def _invalidate_event_fixture(self):
        p = self._queuepool_fixture()
        canary = Mock()
        event.listen(p, "invalidate", canary)

        return p, canary

    def _soft_invalidate_event_fixture(self):
        p = self._queuepool_fixture()
        canary = Mock()
        event.listen(p, "soft_invalidate", canary)

        return p, canary

    def _close_event_fixture(self):
        p = self._queuepool_fixture()
        canary = Mock()
        event.listen(p, "close", canary)

        return p, canary

    def _detach_event_fixture(self):
        p = self._queuepool_fixture()
        canary = Mock()
        event.listen(p, "detach", canary)

        return p, canary

    def _close_detached_event_fixture(self):
        p = self._queuepool_fixture()
        canary = Mock()
        event.listen(p, "close_detached", canary)

        return p, canary

    def test_close(self):
        # "close" fires when the pool disposes the connection, not on
        # an ordinary checkin
        p, canary = self._close_event_fixture()

        c1 = p.connect()

        connection = c1.connection
        rec = c1._connection_record

        c1.close()

        eq_(canary.mock_calls, [])

        p.dispose()
        eq_(canary.mock_calls, [call(connection, rec)])

    def test_detach(self):
        p, canary = self._detach_event_fixture()

        c1 = p.connect()

        connection = c1.connection
        rec = c1._connection_record

        c1.detach()

        eq_(canary.mock_calls, [call(connection, rec)])

    def test_detach_close(self):
        # closing a detached connection fires "close_detached" with the
        # DBAPI connection only (no record)
        p, canary = self._close_detached_event_fixture()

        c1 = p.connect()

        connection = c1.connection

        c1.detach()

        c1.close()
        eq_(canary.mock_calls, [call(connection)])

    def test_first_connect_event(self):
        p, canary = self._first_connect_event_fixture()

        p.connect()
        eq_(canary, ["first_connect"])

    def test_first_connect_event_fires_once(self):
        p, canary = self._first_connect_event_fixture()

        p.connect()
        p.connect()

        eq_(canary, ["first_connect"])

    def test_first_connect_on_previously_recreated(self):
        # a recreated pool gets its own "first" connect
        p, canary = self._first_connect_event_fixture()

        p2 = p.recreate()
        p.connect()
        p2.connect()

        eq_(canary, ["first_connect", "first_connect"])

    def test_first_connect_on_subsequently_recreated(self):
        p, canary = self._first_connect_event_fixture()

        p.connect()
        p2 = p.recreate()
        p2.connect()

        eq_(canary, ["first_connect", "first_connect"])

    def test_connect_event(self):
        p, canary = self._connect_event_fixture()

        p.connect()
        eq_(canary, ["connect"])

    def test_connect_event_fires_subsequent(self):
        p, canary = self._connect_event_fixture()

        c1 = p.connect()  # noqa
        c2 = p.connect()  # noqa

        eq_(canary, ["connect", "connect"])

    def test_connect_on_previously_recreated(self):
        p, canary = self._connect_event_fixture()

        p2 = p.recreate()

        p.connect()
        p2.connect()

        eq_(canary, ["connect", "connect"])

    def test_connect_on_subsequently_recreated(self):
        p, canary = self._connect_event_fixture()

        p.connect()
        p2 = p.recreate()
        p2.connect()

        eq_(canary, ["connect", "connect"])

    def test_checkout_event(self):
        p, canary = self._checkout_event_fixture()

        p.connect()
        eq_(canary, ["checkout"])

    def test_checkout_event_fires_subsequent(self):
        p, canary = self._checkout_event_fixture()

        p.connect()
        p.connect()
        eq_(canary, ["checkout", "checkout"])

    def test_checkout_event_on_subsequently_recreated(self):
        p, canary = self._checkout_event_fixture()

        p.connect()
        p2 = p.recreate()
        p2.connect()

        eq_(canary, ["checkout", "checkout"])

    def test_checkin_event(self):
        p, canary = self._checkin_event_fixture()

        c1 = p.connect()
        eq_(canary, [])
        c1.close()
        eq_(canary, ["checkin"])

    def test_reset_event(self):
        p, canary = self._reset_event_fixture()

        c1 = p.connect()
        eq_(canary, [])
        c1.close()
        eq_(canary, ["reset"])

    # -- invalidate events receive (dbapi_connection, record, exception);
    # -- only the first and third positional args are asserted here

    def test_soft_invalidate_event_no_exception(self):
        p, canary = self._soft_invalidate_event_fixture()

        c1 = p.connect()
        c1.close()
        assert not canary.called
        c1 = p.connect()
        dbapi_con = c1.connection
        c1.invalidate(soft=True)
        assert canary.call_args_list[0][0][0] is dbapi_con
        assert canary.call_args_list[0][0][2] is None

    def test_soft_invalidate_event_exception(self):
        p, canary = self._soft_invalidate_event_fixture()

        c1 = p.connect()
        c1.close()
        assert not canary.called
        c1 = p.connect()
        dbapi_con = c1.connection
        exc = Exception("hi")
        c1.invalidate(exc, soft=True)
        assert canary.call_args_list[0][0][0] is dbapi_con
        assert canary.call_args_list[0][0][2] is exc

    def test_invalidate_event_no_exception(self):
        p, canary = self._invalidate_event_fixture()

        c1 = p.connect()
        c1.close()
        assert not canary.called
        c1 = p.connect()
        dbapi_con = c1.connection
        c1.invalidate()
        assert canary.call_args_list[0][0][0] is dbapi_con
        assert canary.call_args_list[0][0][2] is None

    def test_invalidate_event_exception(self):
        p, canary = self._invalidate_event_fixture()

        c1 = p.connect()
        c1.close()
        assert not canary.called
        c1 = p.connect()
        dbapi_con = c1.connection
        exc = Exception("hi")
        c1.invalidate(exc)
        assert canary.call_args_list[0][0][0] is dbapi_con
        assert canary.call_args_list[0][0][2] is exc

    def test_checkin_event_gc(self):
        # checkin also fires when the proxy is garbage collected rather
        # than explicitly closed
        p, canary = self._checkin_event_fixture()

        c1 = p.connect()
        eq_(canary, [])
        del c1
        lazy_gc()
        eq_(canary, ["checkin"])

    def test_checkin_event_on_subsequently_recreated(self):
        p, canary = self._checkin_event_fixture()

        c1 = p.connect()
        p2 = p.recreate()
        c2 = p2.connect()

        eq_(canary, [])

        c1.close()
        eq_(canary, ["checkin"])

        c2.close()
        eq_(canary, ["checkin", "checkin"])

    def test_listen_targets_scope(self):
        # listeners on the Pool class, the pool instance, the Engine
        # instance and the Engine class should all fire for one connect
        canary = []

        def listen_one(*args):
            canary.append("listen_one")

        def listen_two(*args):
            canary.append("listen_two")

        def listen_three(*args):
            canary.append("listen_three")

        def listen_four(*args):
            canary.append("listen_four")

        engine = testing_engine(testing.db.url)
        event.listen(pool.Pool, "connect", listen_one)
        event.listen(engine.pool, "connect", listen_two)
        event.listen(engine, "connect", listen_three)
        event.listen(engine.__class__, "connect", listen_four)

        engine.execute(select([1])).close()
        eq_(
            canary, ["listen_one", "listen_four", "listen_two", "listen_three"]
        )

    def test_listen_targets_per_subclass(self):
        """test that listen() called on a subclass remains specific to
        that subclass."""

        canary = []

        def listen_one(*args):
            canary.append("listen_one")

        def listen_two(*args):
            canary.append("listen_two")

        def listen_three(*args):
            canary.append("listen_three")

        event.listen(pool.Pool, "connect", listen_one)
        event.listen(pool.QueuePool, "connect", listen_two)
        event.listen(pool.SingletonThreadPool, "connect", listen_three)

        p1 = pool.QueuePool(creator=MockDBAPI().connect)
        p2 = pool.SingletonThreadPool(creator=MockDBAPI().connect)

        assert listen_one in p1.dispatch.connect
        assert listen_two in p1.dispatch.connect
        assert listen_three not in p1.dispatch.connect
        assert listen_one in p2.dispatch.connect
        assert listen_two not in p2.dispatch.connect
        assert listen_three in p2.dispatch.connect

        p1.connect()
        eq_(canary, ["listen_one", "listen_two"])
        p2.connect()
        eq_(canary, ["listen_one", "listen_two", "listen_one", "listen_three"])

    def test_connect_event_fails_invalidates(self):
        # a failing "connect" listener should leave the pool able to
        # produce a fresh, fully-initialized connection afterwards
        fail = False

        def listen_one(conn, rec):
            if fail:
                raise Exception("it failed")

        def listen_two(conn, rec):
            rec.info["important_flag"] = True

        p1 = pool.QueuePool(
            creator=MockDBAPI().connect, pool_size=1, max_overflow=0
        )
        event.listen(p1, "connect", listen_one)
        event.listen(p1, "connect", listen_two)

        conn = p1.connect()
        eq_(conn.info["important_flag"], True)
        conn.invalidate()
        conn.close()

        fail = True
        assert_raises(Exception, p1.connect)

        fail = False

        conn = p1.connect()
        eq_(conn.info["important_flag"], True)
        conn.close()

    def teardown(self):
        # TODO: need to get remove() functionality
        # going
        pool.Pool.dispatch._clear()
778
779
class PoolFirstConnectSyncTest(PoolTestBase):
    """Ensure concurrent checkouts block until the (slow) first_connect
    handler completes, so "connect" never fires before it."""

    # test [ticket:2964]

    @testing.requires.timing_intensive
    def test_sync(self):
        pool = self._queuepool_fixture(pool_size=3, max_overflow=0)

        evt = Mock()

        @event.listens_for(pool, "first_connect")
        def slow_first_connect(dbapi_con, rec):
            # deliberately slow so other threads pile up behind it
            time.sleep(1)
            evt.first_connect()

        @event.listens_for(pool, "connect")
        def on_connect(dbapi_con, rec):
            evt.connect()

        def checkout():
            for j in range(2):
                c1 = pool.connect()
                time.sleep(0.02)
                c1.close()
                time.sleep(0.02)

        threads = []
        for i in range(5):
            th = threading.Thread(target=checkout)
            th.start()
            threads.append(th)
        for th in threads:
            th.join(join_timeout)

        # exactly one first_connect, then one connect per pooled
        # connection (pool_size=3), in that order
        eq_(
            evt.mock_calls,
            [
                call.first_connect(),
                call.connect(),
                call.connect(),
                call.connect(),
            ],
        )
822
823
class DeprecatedPoolListenerTest(PoolTestBase):
    """Tests for the legacy PoolListener interface and the deprecated
    ``listeners=`` pool argument, verifying they map onto the modern
    event dispatch system."""

    @testing.requires.predictable_gc
    @testing.uses_deprecated(
        r".*Use the PoolEvents",
        r".*'listeners' argument .* is deprecated"
    )
    def test_listeners(self):
        class InstrumentingListener(object):
            # wrap whichever hook methods a subclass defines so that
            # calls are recorded in the lists below
            def __init__(self):
                if hasattr(self, "connect"):
                    self.connect = self.inst_connect
                if hasattr(self, "first_connect"):
                    self.first_connect = self.inst_first_connect
                if hasattr(self, "checkout"):
                    self.checkout = self.inst_checkout
                if hasattr(self, "checkin"):
                    self.checkin = self.inst_checkin
                self.clear()

            def clear(self):
                self.connected = []
                self.first_connected = []
                self.checked_out = []
                self.checked_in = []

            def assert_total(self, conn, fconn, cout, cin):
                eq_(len(self.connected), conn)
                eq_(len(self.first_connected), fconn)
                eq_(len(self.checked_out), cout)
                eq_(len(self.checked_in), cin)

            def assert_in(self, item, in_conn, in_fconn, in_cout, in_cin):
                eq_((item in self.connected), in_conn)
                eq_((item in self.first_connected), in_fconn)
                eq_((item in self.checked_out), in_cout)
                eq_((item in self.checked_in), in_cin)

            def inst_connect(self, con, record):
                print("connect(%s, %s)" % (con, record))
                assert con is not None
                assert record is not None
                self.connected.append(con)

            def inst_first_connect(self, con, record):
                print("first_connect(%s, %s)" % (con, record))
                assert con is not None
                assert record is not None
                self.first_connected.append(con)

            def inst_checkout(self, con, record, proxy):
                print("checkout(%s, %s, %s)" % (con, record, proxy))
                assert con is not None
                assert record is not None
                assert proxy is not None
                self.checked_out.append(con)

            def inst_checkin(self, con, record):
                print("checkin(%s, %s)" % (con, record))
                # con can be None if invalidated
                assert record is not None
                self.checked_in.append(con)

        class ListenAll(tsa.interfaces.PoolListener, InstrumentingListener):
            pass

        class ListenConnect(InstrumentingListener):
            def connect(self, con, record):
                pass

        class ListenFirstConnect(InstrumentingListener):
            def first_connect(self, con, record):
                pass

        class ListenCheckOut(InstrumentingListener):
            def checkout(self, con, record, proxy, num):
                pass

        class ListenCheckIn(InstrumentingListener):
            def checkin(self, con, record):
                pass

        def assert_listeners(p, total, conn, fconn, cout, cin):
            # listener counts must survive recreate() as well
            for instance in (p, p.recreate()):
                self.assert_(len(instance.dispatch.connect) == conn)
                self.assert_(len(instance.dispatch.first_connect) == fconn)
                self.assert_(len(instance.dispatch.checkout) == cout)
                self.assert_(len(instance.dispatch.checkin) == cin)

        p = self._queuepool_fixture()
        assert_listeners(p, 0, 0, 0, 0, 0)

        # each partial listener only registers for the hooks it defines
        p.add_listener(ListenAll())
        assert_listeners(p, 1, 1, 1, 1, 1)

        p.add_listener(ListenConnect())
        assert_listeners(p, 2, 2, 1, 1, 1)

        p.add_listener(ListenFirstConnect())
        assert_listeners(p, 3, 2, 2, 1, 1)

        p.add_listener(ListenCheckOut())
        assert_listeners(p, 4, 2, 2, 2, 1)

        p.add_listener(ListenCheckIn())
        assert_listeners(p, 5, 2, 2, 2, 2)
        del p

        snoop = ListenAll()
        p = self._queuepool_fixture(listeners=[snoop])
        assert_listeners(p, 1, 1, 1, 1, 1)

        c = p.connect()
        snoop.assert_total(1, 1, 1, 0)
        cc = c.connection
        snoop.assert_in(cc, True, True, True, False)
        c.close()
        snoop.assert_in(cc, True, True, True, True)
        del c, cc

        snoop.clear()

        # this one depends on immediate gc
        c = p.connect()
        cc = c.connection
        snoop.assert_in(cc, False, False, True, False)
        snoop.assert_total(0, 0, 1, 0)
        del c, cc
        lazy_gc()
        snoop.assert_total(0, 0, 1, 1)

        p.dispose()
        snoop.clear()

        c = p.connect()
        c.close()
        c = p.connect()
        snoop.assert_total(1, 0, 2, 1)
        c.close()
        snoop.assert_total(1, 0, 2, 2)

        # invalidation
        p.dispose()
        snoop.clear()

        c = p.connect()
        snoop.assert_total(1, 0, 1, 0)
        c.invalidate()
        snoop.assert_total(1, 0, 1, 1)
        c.close()
        snoop.assert_total(1, 0, 1, 1)
        del c
        lazy_gc()
        snoop.assert_total(1, 0, 1, 1)
        c = p.connect()
        snoop.assert_total(2, 0, 2, 1)
        c.close()
        del c
        lazy_gc()
        snoop.assert_total(2, 0, 2, 2)

        # detached
        p.dispose()
        snoop.clear()

        c = p.connect()
        snoop.assert_total(1, 0, 1, 0)
        c.detach()
        snoop.assert_total(1, 0, 1, 0)
        c.close()
        del c
        snoop.assert_total(1, 0, 1, 0)
        c = p.connect()
        snoop.assert_total(2, 0, 2, 0)
        c.close()
        del c
        snoop.assert_total(2, 0, 2, 1)

        # recreated
        p = p.recreate()
        snoop.clear()

        c = p.connect()
        snoop.assert_total(1, 1, 1, 0)
        c.close()
        snoop.assert_total(1, 1, 1, 1)
        c = p.connect()
        snoop.assert_total(1, 1, 2, 1)
        c.close()
        snoop.assert_total(1, 1, 2, 2)

    @testing.uses_deprecated(
        r".*Use the PoolEvents",
        r".*'listeners' argument .* is deprecated"
    )
    def test_listeners_callables(self):
        # listeners may also be plain dicts of callables keyed by
        # event name
        def connect(dbapi_con, con_record):
            counts[0] += 1

        def checkout(dbapi_con, con_record, con_proxy):
            counts[1] += 1

        def checkin(dbapi_con, con_record):
            counts[2] += 1

        i_all = dict(connect=connect, checkout=checkout, checkin=checkin)
        i_connect = dict(connect=connect)
        i_checkout = dict(checkout=checkout)
        i_checkin = dict(checkin=checkin)

        for cls in (pool.QueuePool, pool.StaticPool):
            counts = [0, 0, 0]

            def assert_listeners(p, total, conn, cout, cin):
                for instance in (p, p.recreate()):
                    eq_(len(instance.dispatch.connect), conn)
                    eq_(len(instance.dispatch.checkout), cout)
                    eq_(len(instance.dispatch.checkin), cin)

            p = self._queuepool_fixture()
            assert_listeners(p, 0, 0, 0, 0)

            p.add_listener(i_all)
            assert_listeners(p, 1, 1, 1, 1)

            p.add_listener(i_connect)
            assert_listeners(p, 2, 1, 1, 1)

            p.add_listener(i_checkout)
            assert_listeners(p, 3, 1, 1, 1)

            p.add_listener(i_checkin)
            assert_listeners(p, 4, 1, 1, 1)
            del p

            p = self._queuepool_fixture(listeners=[i_all])
            assert_listeners(p, 1, 1, 1, 1)

            c = p.connect()
            assert counts == [1, 1, 0]
            c.close()
            assert counts == [1, 1, 1]

            c = p.connect()
            assert counts == [1, 2, 1]
            p.add_listener(i_checkin)
            c.close()
            assert counts == [1, 2, 2]
1071
1072
1073class QueuePoolTest(PoolTestBase):
1074    def test_queuepool_del(self):
1075        self._do_testqueuepool(useclose=False)
1076
1077    def test_queuepool_close(self):
1078        self._do_testqueuepool(useclose=True)
1079
    def _do_testqueuepool(self, useclose=False):
        """Walk a size-3, unlimited-overflow pool through checkouts and
        checkins, asserting (size, checkedin, overflow, checkedout) at
        each step.

        ``useclose=True`` returns connections via close(); otherwise the
        references are dropped and reclaimed via garbage collection.
        """
        p = self._queuepool_fixture(pool_size=3, max_overflow=-1)

        def status(pool):
            return (
                pool.size(),
                pool.checkedin(),
                pool.overflow(),
                pool.checkedout(),
            )

        # overflow starts at -pool_size and counts upward with each
        # checkout beyond zero checked-in connections
        c1 = p.connect()
        self.assert_(status(p) == (3, 0, -2, 1))
        c2 = p.connect()
        self.assert_(status(p) == (3, 0, -1, 2))
        c3 = p.connect()
        self.assert_(status(p) == (3, 0, 0, 3))
        c4 = p.connect()
        self.assert_(status(p) == (3, 0, 1, 4))
        c5 = p.connect()
        self.assert_(status(p) == (3, 0, 2, 5))
        c6 = p.connect()
        self.assert_(status(p) == (3, 0, 3, 6))
        if useclose:
            c4.close()
            c3.close()
            c2.close()
        else:
            c4 = c3 = c2 = None
            lazy_gc()
        self.assert_(status(p) == (3, 3, 3, 3))
        if useclose:
            c1.close()
            c5.close()
            c6.close()
        else:
            c1 = c5 = c6 = None
            lazy_gc()
        # all returned; overflow connections were discarded
        self.assert_(status(p) == (3, 3, 0, 0))
        c1 = p.connect()
        c2 = p.connect()
        self.assert_(status(p) == (3, 1, 0, 2), status(p))
        if useclose:
            c2.close()
        else:
            c2 = None
            lazy_gc()
        self.assert_(status(p) == (3, 2, 0, 1))
        c1.close()
        lazy_gc()
        # no ConnectionRecord weakrefs may remain
        assert not pool._refs
1130        assert not pool._refs
1131
1132    @testing.requires.timing_intensive
1133    def test_timeout(self):
1134        p = self._queuepool_fixture(pool_size=3, max_overflow=0, timeout=2)
1135        c1 = p.connect()  # noqa
1136        c2 = p.connect()  # noqa
1137        c3 = p.connect()  # noqa
1138        now = time.time()
1139
1140        assert_raises(tsa.exc.TimeoutError, p.connect)
1141        assert int(time.time() - now) == 2
1142
    @testing.requires.threading_with_mock
    @testing.requires.timing_intensive
    def test_timeout_race(self):
        """Threads that lose the race for the pool mutex must still wait
        the full ``timeout`` before raising TimeoutError.

        test a race condition where the initial connecting threads all race
        to queue.Empty, then block on the mutex.  each thread consumes a
        connection as they go in.  when the limit is reached, the remaining
        threads go in, and get TimeoutError; even though they never got to
        wait for the timeout on queue.get().  the fix involves checking the
        timeout again within the mutex, and if so, unlocking and throwing
        them back to the start of do_get()
        """
        dbapi = MockDBAPI()
        p = pool.QueuePool(
            creator=lambda: dbapi.connect(delay=0.05),
            pool_size=2,
            max_overflow=1,
            use_threadlocal=False,
            timeout=3,
        )
        # wall-clock duration of each failed connect attempt
        timeouts = []

        def checkout():
            for x in range(1):
                now = time.time()
                try:
                    c1 = p.connect()
                except tsa.exc.TimeoutError:
                    timeouts.append(time.time() - now)
                    continue
                # hold the connection past the other threads' timeout
                time.sleep(4)
                c1.close()

        threads = []
        for i in range(10):
            th = threading.Thread(target=checkout)
            th.start()
            threads.append(th)
        for th in threads:
            th.join(join_timeout)

        assert len(timeouts) > 0
        for t in timeouts:
            assert t >= 3, "Not all timeouts were >= 3 seconds %r" % timeouts
            # normally, the timeout should be under 4 seconds,
            # but on a loaded down buildbot it can go up.
            assert t < 14, "Not all timeouts were < 14 seconds %r" % timeouts
1188
    def _test_overflow(self, thread_count, max_overflow):
        """Hammer a pool_size=3 QueuePool with ``thread_count`` threads
        and verify the observed overflow never exceeds ``max_overflow``."""
        gc_collect()

        dbapi = MockDBAPI()
        # serializes connection creation; the sleep widens the race window
        mutex = threading.Lock()

        def creator():
            time.sleep(0.05)
            with mutex:
                return dbapi.connect()

        p = pool.QueuePool(
            creator=creator, pool_size=3, timeout=2, max_overflow=max_overflow
        )
        # overflow level sampled while each connection is held
        peaks = []

        def whammy():
            for i in range(10):
                try:
                    con = p.connect()
                    time.sleep(0.005)
                    peaks.append(p.overflow())
                    con.close()
                    del con
                except tsa.exc.TimeoutError:
                    pass

        threads = []
        for i in range(thread_count):
            th = threading.Thread(target=whammy)
            th.start()
            threads.append(th)
        for th in threads:
            th.join(join_timeout)

        self.assert_(max(peaks) <= max_overflow)

        lazy_gc()
        # all ConnectionRecord weakrefs should be gone
        assert not pool._refs
1228
1229    def test_overflow_reset_on_failed_connect(self):
1230        dbapi = Mock()
1231
1232        def failing_dbapi():
1233            time.sleep(2)
1234            raise Exception("connection failed")
1235
1236        creator = dbapi.connect
1237
1238        def create():
1239            return creator()
1240
1241        p = pool.QueuePool(creator=create, pool_size=2, max_overflow=3)
1242        c1 = self._with_teardown(p.connect())  # noqa
1243        c2 = self._with_teardown(p.connect())  # noqa
1244        c3 = self._with_teardown(p.connect())  # noqa
1245        eq_(p._overflow, 1)
1246        creator = failing_dbapi
1247        assert_raises(Exception, p.connect)
1248        eq_(p._overflow, 1)
1249
    @testing.requires.threading_with_mock
    @testing.requires.timing_intensive
    def test_hanging_connect_within_overflow(self):
        """test that a single connect() call which is hanging
        does not block other connections from proceeding."""

        dbapi = Mock()
        mutex = threading.Lock()

        # connector that stalls for 2 seconds before connecting
        def hanging_dbapi():
            time.sleep(2)
            with mutex:
                return dbapi.connect()

        def fast_dbapi():
            with mutex:
                return dbapi.connect()

        # each thread selects its own connector via this thread-local
        creator = threading.local()

        def create():
            return creator.mock_connector()

        def run_test(name, pool, should_hang):
            if should_hang:
                creator.mock_connector = hanging_dbapi
            else:
                creator.mock_connector = fast_dbapi

            conn = pool.connect()
            conn.operation(name)
            time.sleep(1)
            conn.close()

        p = pool.QueuePool(creator=create, pool_size=2, max_overflow=3)

        threads = [
            threading.Thread(target=run_test, args=("success_one", p, False)),
            threading.Thread(target=run_test, args=("success_two", p, False)),
            threading.Thread(target=run_test, args=("overflow_one", p, True)),
            threading.Thread(target=run_test, args=("overflow_two", p, False)),
            threading.Thread(
                target=run_test, args=("overflow_three", p, False)
            ),
        ]
        # staggered starts so the expected ordering is deterministic
        for t in threads:
            t.start()
            time.sleep(0.2)

        for t in threads:
            t.join(timeout=join_timeout)
        # the hanging overflow_one finishes last; the others proceeded
        # without waiting on it
        eq_(
            dbapi.connect().operation.mock_calls,
            [
                call("success_one"),
                call("success_two"),
                call("overflow_two"),
                call("overflow_three"),
                call("overflow_one"),
            ],
        )
1311
    @testing.requires.threading_with_mock
    @testing.requires.timing_intensive
    def test_waiters_handled(self):
        """test that threads waiting for connections are
        handled when the pool is replaced.

        """
        mutex = threading.Lock()
        dbapi = MockDBAPI()

        def creator():
            mutex.acquire()
            try:
                return dbapi.connect()
            finally:
                mutex.release()

        # (timeout, max_overflow) keys recorded by each waiter that
        # successfully checked out a connection
        success = []
        for timeout in (None, 30):
            for max_overflow in (0, -1, 3):
                p = pool.QueuePool(
                    creator=creator,
                    pool_size=2,
                    timeout=timeout,
                    max_overflow=max_overflow,
                )

                def waiter(p, timeout, max_overflow):
                    success_key = (timeout, max_overflow)
                    conn = p.connect()
                    success.append(success_key)
                    time.sleep(0.1)
                    conn.close()

                c1 = p.connect()  # noqa
                c2 = p.connect()

                threads = []
                for i in range(2):
                    t = threading.Thread(
                        target=waiter, args=(p, timeout, max_overflow)
                    )
                    t.daemon = True
                    t.start()
                    threads.append(t)

                # this sleep makes sure that the
                # two waiter threads hit upon wait()
                # inside the queue, before we invalidate the other
                # two conns
                time.sleep(0.2)
                p._invalidate(c2)

                for t in threads:
                    t.join(join_timeout)

        # 2 waiters x 6 pool configurations = 12 successful checkouts
        eq_(len(success), 12, "successes: %s" % success)
1369
    def test_connrec_invalidated_within_checkout_no_race(self):
        """Test that a concurrent ConnectionRecord.invalidate() which
        occurs after the ConnectionFairy has called
        _ConnectionRecord.checkout()
        but before the ConnectionFairy tests "fairy.connection is None"
        will not result in an InvalidRequestError.

        This use case assumes that a listener on the checkout() event
        will be raising DisconnectionError so that a reconnect attempt
        may occur.

        """
        dbapi = MockDBAPI()

        def creator():
            return dbapi.connect()

        p = pool.QueuePool(creator=creator, pool_size=1, max_overflow=0)

        conn = p.connect()
        conn.close()

        _existing_checkout = pool._ConnectionRecord.checkout

        # wrap checkout() so that every checkout is immediately
        # invalidated, simulating the concurrent invalidation race
        @classmethod
        def _decorate_existing_checkout(cls, *arg, **kw):
            fairy = _existing_checkout(*arg, **kw)
            connrec = fairy._connection_record
            connrec.invalidate()
            return fairy

        with patch(
            "sqlalchemy.pool._ConnectionRecord.checkout",
            _decorate_existing_checkout,
        ):
            conn = p.connect()
            # checkout succeeded despite the record being invalidated
            is_(conn._connection_record.connection, None)
        conn.close()
1408
    @testing.requires.threading_with_mock
    @testing.requires.timing_intensive
    def test_notify_waiters(self):
        """Invalidating the checked-out connection of a full pool must
        wake all threads blocked waiting on a checkout."""
        dbapi = MockDBAPI()

        # canary records 1 per connection created, 2 per waiter checkout
        canary = []

        def creator():
            canary.append(1)
            return dbapi.connect()

        p1 = pool.QueuePool(
            creator=creator, pool_size=1, timeout=None, max_overflow=0
        )

        def waiter(p):
            conn = p.connect()
            canary.append(2)
            time.sleep(0.5)
            conn.close()

        c1 = p1.connect()

        threads = []
        for i in range(5):
            t = threading.Thread(target=waiter, args=(p1,))
            t.start()
            threads.append(t)
        # all five waiters are blocked; only the first connect happened
        time.sleep(0.5)
        eq_(canary, [1])

        # this also calls invalidate()
        # on c1
        p1._invalidate(c1)

        for t in threads:
            t.join(join_timeout)

        # one reconnect, then all five waiters got through
        eq_(canary, [1, 1, 2, 2, 2, 2, 2])
1448
1449    def test_dispose_closes_pooled(self):
1450        dbapi = MockDBAPI()
1451
1452        p = pool.QueuePool(
1453            creator=dbapi.connect, pool_size=2, timeout=None, max_overflow=0
1454        )
1455        c1 = p.connect()
1456        c2 = p.connect()
1457        c1_con = c1.connection
1458        c2_con = c2.connection
1459
1460        c1.close()
1461
1462        eq_(c1_con.close.call_count, 0)
1463        eq_(c2_con.close.call_count, 0)
1464
1465        p.dispose()
1466
1467        eq_(c1_con.close.call_count, 1)
1468        eq_(c2_con.close.call_count, 0)
1469
1470        # currently, if a ConnectionFairy is closed
1471        # after the pool has been disposed, there's no
1472        # flag that states it should be invalidated
1473        # immediately - it just gets returned to the
1474        # pool normally...
1475        c2.close()
1476        eq_(c1_con.close.call_count, 1)
1477        eq_(c2_con.close.call_count, 0)
1478
1479        # ...and that's the one we'll get back next.
1480        c3 = p.connect()
1481        assert c3.connection is c2_con
1482
1483    @testing.requires.threading_with_mock
1484    @testing.requires.timing_intensive
1485    def test_no_overflow(self):
1486        self._test_overflow(40, 0)
1487
1488    @testing.requires.threading_with_mock
1489    @testing.requires.timing_intensive
1490    def test_max_overflow(self):
1491        self._test_overflow(40, 5)
1492
1493    def test_mixed_close(self):
1494        pool._refs.clear()
1495        p = self._queuepool_fixture(
1496            pool_size=3, max_overflow=-1, use_threadlocal=True
1497        )
1498        c1 = p.connect()
1499        c2 = p.connect()
1500        assert c1 is c2
1501        c1.close()
1502        c2 = None
1503        assert p.checkedout() == 1
1504        c1 = None
1505        lazy_gc()
1506        assert p.checkedout() == 0
1507        lazy_gc()
1508        assert not pool._refs
1509
1510    def test_overflow_no_gc_tlocal(self):
1511        self._test_overflow_no_gc(True)
1512
1513    def test_overflow_no_gc(self):
1514        self._test_overflow_no_gc(False)
1515
    def _test_overflow_no_gc(self, threadlocal):
        """Open/close batches of connections while holding strong refs to
        the raw DBAPI connections, then verify exactly the two pooled
        connections remain unclosed.

        NOTE(review): the ``threadlocal`` argument is currently unused --
        the fixture is built without ``use_threadlocal``.  Confirm whether
        the tlocal variant is still meant to behave differently.
        """
        p = self._queuepool_fixture(pool_size=2, max_overflow=2)

        # disable weakref collection of the
        # underlying connections
        strong_refs = set()

        def _conn():
            c = p.connect()
            strong_refs.add(c.connection)
            return c

        for j in range(5):
            # open 4 conns at a time.  each time this
            # will yield two pooled connections + two
            # overflow connections.
            conns = [_conn() for i in range(4)]
            for c in conns:
                c.close()

        # doing that for a total of 5 times yields
        # ten overflow connections closed plus the
        # two pooled connections unclosed.

        eq_(
            set([c.close.call_count for c in strong_refs]),
            set([1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0]),
        )
1544
1545    @testing.requires.predictable_gc
1546    def test_weakref_kaboom(self):
1547        p = self._queuepool_fixture(
1548            pool_size=3, max_overflow=-1, use_threadlocal=True
1549        )
1550        c1 = p.connect()
1551        c2 = p.connect()
1552        c1.close()
1553        c2 = None
1554        del c1
1555        del c2
1556        gc_collect()
1557        assert p.checkedout() == 0
1558        c3 = p.connect()
1559        assert c3 is not None
1560
1561    def test_trick_the_counter(self):
1562        """this is a "flaw" in the connection pool; since threadlocal
1563        uses a single ConnectionFairy per thread with an open/close
1564        counter, you can fool the counter into giving you a
1565        ConnectionFairy with an ambiguous counter.  i.e. its not true
1566        reference counting."""
1567
1568        p = self._queuepool_fixture(
1569            pool_size=3, max_overflow=-1, use_threadlocal=True
1570        )
1571        c1 = p.connect()
1572        c2 = p.connect()
1573        assert c1 is c2
1574        c1.close()
1575        c2 = p.connect()
1576        c2.close()
1577        self.assert_(p.checkedout() != 0)
1578        c2.close()
1579        self.assert_(p.checkedout() == 0)
1580
1581    def test_recycle(self):
1582        with patch("sqlalchemy.pool.time.time") as mock:
1583            mock.return_value = 10000
1584
1585            p = self._queuepool_fixture(
1586                pool_size=1, max_overflow=0, recycle=30
1587            )
1588            c1 = p.connect()
1589            c_ref = weakref.ref(c1.connection)
1590            c1.close()
1591            mock.return_value = 10001
1592            c2 = p.connect()
1593
1594            is_(c2.connection, c_ref())
1595            c2.close()
1596
1597            mock.return_value = 10035
1598            c3 = p.connect()
1599            is_not_(c3.connection, c_ref())
1600
1601    @testing.requires.timing_intensive
1602    def test_recycle_on_invalidate(self):
1603        p = self._queuepool_fixture(pool_size=1, max_overflow=0)
1604        c1 = p.connect()
1605        c_ref = weakref.ref(c1.connection)
1606        c1.close()
1607        c2 = p.connect()
1608        is_(c2.connection, c_ref())
1609
1610        c2_rec = c2._connection_record
1611        p._invalidate(c2)
1612        assert c2_rec.connection is None
1613        c2.close()
1614        time.sleep(0.5)
1615        c3 = p.connect()
1616
1617        is_not_(c3.connection, c_ref())
1618
1619    @testing.requires.timing_intensive
1620    def test_recycle_on_soft_invalidate(self):
1621        p = self._queuepool_fixture(pool_size=1, max_overflow=0)
1622        c1 = p.connect()
1623        c_ref = weakref.ref(c1.connection)
1624        c1.close()
1625        c2 = p.connect()
1626        is_(c2.connection, c_ref())
1627
1628        c2_rec = c2._connection_record
1629        c2.invalidate(soft=True)
1630        is_(c2_rec.connection, c2.connection)
1631
1632        c2.close()
1633        time.sleep(0.5)
1634        c3 = p.connect()
1635        is_not_(c3.connection, c_ref())
1636        is_(c3._connection_record, c2_rec)
1637        is_(c2_rec.connection, c3.connection)
1638
    def _no_wr_finalize(self):
        """Return a patch context that fails the test if _finalize_fairy
        is ever invoked as a weakref callback (``fairy=None``) rather
        than from an explicit checkin."""
        finalize_fairy = pool._finalize_fairy

        # note: the ``pool`` parameter shadows the module only inside
        # this wrapper's body
        def assert_no_wr_callback(
            connection, connection_record, pool, ref, echo, fairy=None
        ):
            if fairy is None:
                raise AssertionError(
                    "finalize fairy was called as a weakref callback"
                )
            return finalize_fairy(
                connection, connection_record, pool, ref, echo, fairy
            )

        return patch.object(pool, "_finalize_fairy", assert_no_wr_callback)
1654
    def _assert_cleanup_on_pooled_reconnect(self, dbapi, p):
        """Assert a failed reconnect of a stale pooled connection leaves
        the pool's overflow/checkout bookkeeping intact."""
        # p is QueuePool with size=1, max_overflow=2,
        # and one connection in the pool that will need to
        # reconnect when next used (either due to recycle or invalidate)

        with self._no_wr_finalize():
            eq_(p.checkedout(), 0)
            eq_(p._overflow, 0)
            dbapi.shutdown(True)
            assert_raises(Exception, p.connect)
            eq_(p._overflow, 0)
            eq_(p.checkedout(), 0)  # and not 1

            dbapi.shutdown(False)

            c1 = self._with_teardown(p.connect())  # noqa
            assert p._pool.empty()  # poolsize is one, so we're empty OK
            c2 = self._with_teardown(p.connect())  # noqa
            eq_(p._overflow, 1)  # and not 2

            # this hangs if p._overflow is 2
            c3 = self._with_teardown(p.connect())

            c3.close()
1679
1680    def test_error_on_pooled_reconnect_cleanup_invalidate(self):
1681        dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=2)
1682        c1 = p.connect()
1683        c1.invalidate()
1684        c1.close()
1685        self._assert_cleanup_on_pooled_reconnect(dbapi, p)
1686
1687    @testing.requires.timing_intensive
1688    def test_error_on_pooled_reconnect_cleanup_recycle(self):
1689        dbapi, p = self._queuepool_dbapi_fixture(
1690            pool_size=1, max_overflow=2, recycle=1
1691        )
1692        c1 = p.connect()
1693        c1.close()
1694        time.sleep(1.5)
1695        self._assert_cleanup_on_pooled_reconnect(dbapi, p)
1696
    def test_connect_handler_not_called_for_recycled(self):
        """test [ticket:3497]"""

        dbapi, p = self._queuepool_dbapi_fixture(pool_size=2, max_overflow=2)

        canary = Mock()

        c1 = p.connect()
        c2 = p.connect()

        c1.close()
        c2.close()

        dbapi.shutdown(True)

        # force an invalidation so the pool has a recent _invalidate_time
        bad = p.connect()
        p._invalidate(bad)
        bad.close()
        assert p._invalidate_time

        event.listen(p, "connect", canary.connect)
        event.listen(p, "checkout", canary.checkout)

        # connect fails while the mock DBAPI is shut down
        assert_raises(Exception, p.connect)

        # prune records whose connection was cleared by the invalidation
        p._pool.queue = collections.deque(
            [c for c in p._pool.queue if c.connection is not None]
        )

        dbapi.shutdown(False)
        c = p.connect()
        c.close()

        # exactly one connect event and one checkout event were emitted
        eq_(
            canary.mock_calls,
            [call.connect(ANY, ANY), call.checkout(ANY, ANY, ANY)],
        )
1734
    def test_connect_checkout_handler_always_gets_info(self):
        """test [ticket:3497]"""

        dbapi, p = self._queuepool_dbapi_fixture(pool_size=2, max_overflow=2)

        c1 = p.connect()
        c2 = p.connect()

        c1.close()
        c2.close()

        dbapi.shutdown(True)

        # force an invalidation so the pool has a recent _invalidate_time
        bad = p.connect()
        p._invalidate(bad)
        bad.close()
        assert p._invalidate_time

        # the connect handler stamps info; the checkout handler asserts
        # that stamp is always present
        @event.listens_for(p, "connect")
        def connect(conn, conn_rec):
            conn_rec.info["x"] = True

        @event.listens_for(p, "checkout")
        def checkout(conn, conn_rec, conn_f):
            assert "x" in conn_rec.info

        assert_raises(Exception, p.connect)

        # prune records whose connection was cleared by the invalidation
        p._pool.queue = collections.deque(
            [c for c in p._pool.queue if c.connection is not None]
        )

        dbapi.shutdown(False)
        c = p.connect()
        c.close()
1770
1771    def test_error_on_pooled_reconnect_cleanup_wcheckout_event(self):
1772        dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=2)
1773
1774        c1 = p.connect()
1775        c1.close()
1776
1777        @event.listens_for(p, "checkout")
1778        def handle_checkout_event(dbapi_con, con_record, con_proxy):
1779            if dbapi.is_shutdown:
1780                raise tsa.exc.DisconnectionError()
1781
1782        self._assert_cleanup_on_pooled_reconnect(dbapi, p)
1783
1784    @testing.requires.predictable_gc
1785    def test_userspace_disconnectionerror_weakref_finalizer(self):
1786        dbapi, pool = self._queuepool_dbapi_fixture(
1787            pool_size=1, max_overflow=2
1788        )
1789
1790        @event.listens_for(pool, "checkout")
1791        def handle_checkout_event(dbapi_con, con_record, con_proxy):
1792            if getattr(dbapi_con, "boom") == "yes":
1793                raise tsa.exc.DisconnectionError()
1794
1795        conn = pool.connect()
1796        old_dbapi_conn = conn.connection
1797        conn.close()
1798
1799        eq_(old_dbapi_conn.mock_calls, [call.rollback()])
1800
1801        old_dbapi_conn.boom = "yes"
1802
1803        conn = pool.connect()
1804        dbapi_conn = conn.connection
1805        del conn
1806        gc_collect()
1807
1808        # new connection was reset on return appropriately
1809        eq_(dbapi_conn.mock_calls, [call.rollback()])
1810
1811        # old connection was just closed - did not get an
1812        # erroneous reset on return
1813        eq_(old_dbapi_conn.mock_calls, [call.rollback(), call.close()])
1814
    @testing.requires.timing_intensive
    def test_recycle_pool_no_race(self):
        """Concurrent error-driven invalidation across many connections
        must not spawn more than one replacement pool, and every closed
        connection must be fully closed exactly once."""

        # raw close() takes 0.5s, widening the race window
        def slow_close():
            slow_closing_connection._slow_close()
            time.sleep(0.5)

        slow_closing_connection = Mock()
        slow_closing_connection.connect.return_value.close = slow_close

        class Error(Exception):
            pass

        # dialect that treats every error as a disconnect
        dialect = Mock()
        dialect.is_disconnect = lambda *arg, **kw: True
        dialect.dbapi.Error = Error

        # records every QueuePool instance ever constructed
        pools = []

        class TrackQueuePool(pool.QueuePool):
            def __init__(self, *arg, **kw):
                pools.append(self)
                super(TrackQueuePool, self).__init__(*arg, **kw)

        def creator():
            return slow_closing_connection.connect()

        p1 = TrackQueuePool(creator=creator, pool_size=20)

        from sqlalchemy import create_engine

        eng = create_engine(testing.db.url, pool=p1, _initialize=False)
        eng.dialect = dialect

        # 15 total connections
        conns = [eng.connect() for i in range(15)]

        # return 8 back to the pool
        for conn in conns[3:10]:
            conn.close()

        def attempt(conn):
            time.sleep(random.random())
            try:
                conn._handle_dbapi_exception(
                    Error(), "statement", {}, Mock(), Mock()
                )
            except tsa.exc.DBAPIError:
                pass

        # run an error + invalidate operation on the remaining 7 open
        # connections
        threads = []
        for conn in conns:
            t = threading.Thread(target=attempt, args=(conn,))
            t.start()
            threads.append(t)

        for t in threads:
            t.join()

        # return all 15 connections to the pool
        for conn in conns:
            conn.close()

        # re-open 15 total connections
        conns = [eng.connect() for i in range(15)]

        # 15 connections have been fully closed due to invalidate
        assert slow_closing_connection._slow_close.call_count == 15

        # 15 initial connections + 15 reconnections
        assert slow_closing_connection.connect.call_count == 30
        assert len(pools) <= 2, len(pools)
1888
1889    def test_invalidate(self):
1890        p = self._queuepool_fixture(pool_size=1, max_overflow=0)
1891        c1 = p.connect()
1892        c_id = c1.connection.id
1893        c1.close()
1894        c1 = None
1895        c1 = p.connect()
1896        assert c1.connection.id == c_id
1897        c1.invalidate()
1898        c1 = None
1899        c1 = p.connect()
1900        assert c1.connection.id != c_id
1901
1902    def test_recreate(self):
1903        p = self._queuepool_fixture(
1904            reset_on_return=None, pool_size=1, max_overflow=0
1905        )
1906        p2 = p.recreate()
1907        assert p2.size() == 1
1908        assert p2._reset_on_return is pool.reset_none
1909        assert p2._use_threadlocal is False
1910        assert p2._max_overflow == 0
1911
1912    def test_reconnect(self):
1913        """tests reconnect operations at the pool level.  SA's
1914        engine/dialect includes another layer of reconnect support for
1915        'database was lost' errors."""
1916
1917        dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=0)
1918        c1 = p.connect()
1919        c_id = c1.connection.id
1920        c1.close()
1921        c1 = None
1922        c1 = p.connect()
1923        assert c1.connection.id == c_id
1924        dbapi.raise_error = True
1925        c1.invalidate()
1926        c1 = None
1927        c1 = p.connect()
1928        assert c1.connection.id != c_id
1929
1930    def test_detach(self):
1931        dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=0)
1932        c1 = p.connect()
1933        c1.detach()
1934        c2 = p.connect()  # noqa
1935        eq_(dbapi.connect.mock_calls, [call("foo.db"), call("foo.db")])
1936
1937        c1_con = c1.connection
1938        assert c1_con is not None
1939        eq_(c1_con.close.call_count, 0)
1940        c1.close()
1941        eq_(c1_con.close.call_count, 1)
1942
1943    def test_detach_via_invalidate(self):
1944        dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=0)
1945
1946        c1 = p.connect()
1947        c1_con = c1.connection
1948        c1.invalidate()
1949        assert c1.connection is None
1950        eq_(c1_con.close.call_count, 1)
1951
1952        c2 = p.connect()
1953        assert c2.connection is not c1_con
1954        c2_con = c2.connection
1955
1956        c2.close()
1957        eq_(c2_con.close.call_count, 0)
1958
1959    def test_threadfairy(self):
1960        p = self._queuepool_fixture(
1961            pool_size=3, max_overflow=-1, use_threadlocal=True
1962        )
1963        c1 = p.connect()
1964        c1.close()
1965        c2 = p.connect()
1966        assert c2.connection is not None
1967
1968    def test_no_double_checkin(self):
1969        p = self._queuepool_fixture(pool_size=1)
1970
1971        c1 = p.connect()
1972        rec = c1._connection_record
1973        c1.close()
1974        assert_raises_message(
1975            Warning, "Double checkin attempted on %s" % rec, rec.checkin
1976        )
1977
1978
class ResetOnReturnTest(PoolTestBase):
    """Tests for the reset_on_return behaviors of QueuePool, including
    the _reset_agent hook that overrides the plain rollback/commit."""

    def _fixture(self, **kw):
        """Return (mock dbapi, QueuePool) whose creator calls
        ``dbapi.connect('foo.db')``."""
        dbapi = Mock()
        p = pool.QueuePool(creator=lambda: dbapi.connect("foo.db"), **kw)
        return dbapi, p

    def test_plain_rollback(self):
        dbapi, p = self._fixture(reset_on_return="rollback")

        conn = p.connect()
        conn.close()
        # checkin issues rollback() and nothing else
        assert dbapi.connect().rollback.called
        assert not dbapi.connect().commit.called

    def test_plain_commit(self):
        dbapi, p = self._fixture(reset_on_return="commit")

        conn = p.connect()
        conn.close()
        # checkin issues commit() and nothing else
        assert not dbapi.connect().rollback.called
        assert dbapi.connect().commit.called

    def test_plain_none(self):
        dbapi, p = self._fixture(reset_on_return=None)

        conn = p.connect()
        conn.close()
        # no reset of any kind on checkin
        assert not dbapi.connect().rollback.called
        assert not dbapi.connect().commit.called

    def test_agent_rollback(self):
        dbapi, p = self._fixture(reset_on_return="rollback")

        class TrackingAgent(object):
            def __init__(self, conn):
                self.conn = conn

            def rollback(self):
                self.conn.special_rollback()

            def commit(self):
                self.conn.special_commit()

        fairy = p.connect()
        fairy._reset_agent = TrackingAgent(fairy)
        fairy.close()

        # the agent's rollback ran in place of the plain rollback
        assert dbapi.connect().special_rollback.called
        assert not dbapi.connect().special_commit.called

        assert not dbapi.connect().rollback.called
        assert not dbapi.connect().commit.called

        # without an agent, the plain rollback runs on checkin
        fairy = p.connect()
        fairy.close()
        eq_(dbapi.connect().special_rollback.call_count, 1)
        eq_(dbapi.connect().special_commit.call_count, 0)

        assert dbapi.connect().rollback.called
        assert not dbapi.connect().commit.called

    def test_agent_commit(self):
        dbapi, p = self._fixture(reset_on_return="commit")

        class TrackingAgent(object):
            def __init__(self, conn):
                self.conn = conn

            def rollback(self):
                self.conn.special_rollback()

            def commit(self):
                self.conn.special_commit()

        fairy = p.connect()
        fairy._reset_agent = TrackingAgent(fairy)
        fairy.close()
        # the agent's commit ran in place of the plain commit
        assert not dbapi.connect().special_rollback.called
        assert dbapi.connect().special_commit.called

        assert not dbapi.connect().rollback.called
        assert not dbapi.connect().commit.called

        # without an agent, the plain commit runs on checkin
        fairy = p.connect()
        fairy.close()

        eq_(dbapi.connect().special_rollback.call_count, 0)
        eq_(dbapi.connect().special_commit.call_count, 1)
        assert not dbapi.connect().rollback.called
        assert dbapi.connect().commit.called

    def test_reset_agent_disconnect(self):
        dbapi, p = self._fixture(reset_on_return="rollback")

        class FailingAgent(object):
            def __init__(self, conn):
                self.conn = conn

            def rollback(self):
                # simulate a disconnect occurring during the reset
                p._invalidate(self.conn)
                raise Exception("hi")

            def commit(self):
                self.conn.commit()

        fairy = p.connect()
        fairy._reset_agent = FailingAgent(fairy)
        fairy.close()

        # no warning raised.  We know it would warn due to
        # QueuePoolTest.test_no_double_checkin
2092
2093
class SingletonThreadPoolTest(PoolTestBase):
    @testing.requires.threading_with_mock
    def test_cleanup(self):
        self._test_cleanup(False)

    #   TODO: the SingletonThreadPool cleanup method
    #   has an unfixed race condition within the "cleanup" system that
    #   leads to this test being off by one connection under load; in any
    #   case, this connection will be closed once it is garbage collected.
    #   this pool is not a production-level pool and is only used for the
    #   SQLite "memory" connection, and is not very useful under actual
    #   multi-threaded conditions
    #    @testing.requires.threading_with_mock
    #    def test_cleanup_no_gc(self):
    #       self._test_cleanup(True)

    def _test_cleanup(self, strong_refs):
        """test that the pool's connections are OK after cleanup() has
        been called."""

        dbapi = MockDBAPI()

        creator_lock = threading.Lock()

        def creator():
            # the mock iterator isn't threadsafe...
            with creator_lock:
                return dbapi.connect()

        p = pool.SingletonThreadPool(creator=creator, pool_size=3)

        if strong_refs:
            # hold hard references so connections can't be gc'ed,
            # letting us count which ones were actually closed
            held = set()

            def _conn():
                c = p.connect()
                held.add(c.connection)
                return c

        else:

            def _conn():
                return p.connect()

        def checkout():
            for _ in range(10):
                c = _conn()
                assert c
                c.cursor()
                c.close()
                time.sleep(0.1)

        workers = [threading.Thread(target=checkout) for _ in range(10)]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join(join_timeout)

        # the pool retains exactly pool_size connections after the storm
        eq_(len(p._all_conns), 3)

        if strong_refs:
            still_opened = len([c for c in held if not c.close.call_count])
            eq_(still_opened, 3)
2158
2159
class AssertionPoolTest(PoolTestBase):
    """AssertionPool allows at most one connection checked out at a time."""

    def test_connect_error(self):
        dbapi = MockDBAPI()
        p = pool.AssertionPool(creator=lambda: dbapi.connect("foo.db"))
        held = p.connect()  # noqa
        # a second concurrent checkout is an assertion failure
        assert_raises(AssertionError, p.connect)

    def test_connect_multiple(self):
        dbapi = MockDBAPI()
        p = pool.AssertionPool(creator=lambda: dbapi.connect("foo.db"))
        # serial checkouts are fine as long as each is checked in first
        for _ in range(2):
            conn = p.connect()
            conn.close()

        held = p.connect()  # noqa
        # but holding one while asking for another still fails
        assert_raises(AssertionError, p.connect)
2177
2178
class NullPoolTest(PoolTestBase):
    def test_reconnect(self):
        """NullPool opens a fresh DBAPI connection after close or
        invalidate."""
        dbapi = MockDBAPI()
        p = pool.NullPool(creator=lambda: dbapi.connect("foo.db"))

        conn = p.connect()
        conn.close()
        conn = None

        conn = p.connect()
        conn.invalidate()
        conn = None

        conn = p.connect()
        # at least two distinct DBAPI connects occurred across checkouts
        dbapi.connect.assert_has_calls(
            [call("foo.db"), call("foo.db")], any_order=True
        )
2196
2197
class StaticPoolTest(PoolTestBase):
    def test_recreate(self):
        """recreate() hands the identical creator callable to the new
        pool."""
        dbapi = MockDBAPI()

        def make_conn():
            return dbapi.connect("foo.db")

        original = pool.StaticPool(make_conn)
        duplicate = original.recreate()
        assert original._creator is duplicate._creator
2208
2209
2210class CreatorCompatibilityTest(PoolTestBase):
2211    def test_creator_callable_outside_noarg(self):
2212        e = testing_engine()
2213
2214        creator = e.pool._creator
2215        try:
2216            conn = creator()
2217        finally:
2218            conn.close()
2219
2220    def test_creator_callable_outside_witharg(self):
2221        e = testing_engine()
2222
2223        creator = e.pool._creator
2224        try:
2225            conn = creator(Mock())
2226        finally:
2227            conn.close()
2228
2229    def test_creator_patching_arg_to_noarg(self):
2230        e = testing_engine()
2231        creator = e.pool._creator
2232        try:
2233            # the creator is the two-arg form
2234            conn = creator(Mock())
2235        finally:
2236            conn.close()
2237
2238        def mock_create():
2239            return creator()
2240
2241        conn = e.connect()
2242        conn.invalidate()
2243        conn.close()
2244
2245        # test that the 'should_wrap_creator' status
2246        # will dynamically switch if the _creator is monkeypatched.
2247
2248        # patch it with a zero-arg form
2249        with patch.object(e.pool, "_creator", mock_create):
2250            conn = e.connect()
2251            conn.invalidate()
2252            conn.close()
2253
2254        conn = e.connect()
2255        conn.close()
2256