1import re
2import time
3
4import sqlalchemy as tsa
5from sqlalchemy import column
6from sqlalchemy import create_engine
7from sqlalchemy import engine_from_config
8from sqlalchemy import event
9from sqlalchemy import ForeignKey
10from sqlalchemy import func
11from sqlalchemy import inspect
12from sqlalchemy import INT
13from sqlalchemy import Integer
14from sqlalchemy import literal
15from sqlalchemy import MetaData
16from sqlalchemy import pool
17from sqlalchemy import select
18from sqlalchemy import Sequence
19from sqlalchemy import String
20from sqlalchemy import testing
21from sqlalchemy import text
22from sqlalchemy import TypeDecorator
23from sqlalchemy import VARCHAR
24from sqlalchemy.engine.base import Engine
25from sqlalchemy.interfaces import ConnectionProxy
26from sqlalchemy.testing import assert_raises_message
27from sqlalchemy.testing import engines
28from sqlalchemy.testing import eq_
29from sqlalchemy.testing import fixtures
30from sqlalchemy.testing.engines import testing_engine
31from sqlalchemy.testing.mock import call
32from sqlalchemy.testing.mock import Mock
33from sqlalchemy.testing.schema import Column
34from sqlalchemy.testing.schema import Table
35from sqlalchemy.testing.util import gc_collect
36from sqlalchemy.testing.util import lazy_gc
37from .test_parseconnect import mock_dbapi
38
# Shared threadlocal engine.  TLTransactionTest.setup_class rebinds this
# module global (via ``global``) and the class's tests use it directly.
tlengine = None
40
41
class SomeException(Exception):
    """Error raised on purpose by test callables to force a rollback."""
44
45
def _tlengine_deprecated():
    """Return a context manager expecting the threadlocal-strategy warning.

    Wrap any ``testing_engine(options=dict(strategy="threadlocal"))`` call
    in it so the deprecation warning is asserted rather than raised.
    """
    warning_text = "The 'threadlocal' engine strategy is deprecated"
    return testing.expect_deprecated(warning_text)
50
51
class TableNamesOrderByTest(fixtures.TestBase):
    @testing.provide_metadata
    def test_order_by_foreign_key(self):
        """get_table_names(order_by="foreign_key") lists parents first.

        Builds the chain t1 <- t2 <- t3, where each table carries a foreign
        key to the previous one, then checks the deprecated ordering option.
        """
        parent = None
        for name in ("t1", "t2", "t3"):
            columns = [Column("id", Integer, primary_key=True)]
            if parent is not None:
                # e.g. t2 gets column "t1id" referencing t1.id
                columns.append(
                    Column(
                        "%sid" % parent,
                        Integer,
                        ForeignKey("%s.id" % parent),
                    )
                )
            Table(name, self.metadata, *columns, test_needs_acid=True)
            parent = name

        self.metadata.create_all()
        inspector = inspect(testing.db)
        with testing.expect_deprecated(
            "The get_table_names.order_by parameter is deprecated "
        ):
            table_names = inspector.get_table_names(order_by="foreign_key")
        eq_(table_names, ["t1", "t2", "t3"])
82
83
class CreateEngineTest(fixtures.TestBase):
    def test_pool_threadlocal_from_config(self):
        """The string flag sqlalchemy.pool_threadlocal reaches the pool.

        "false" is accepted silently; "true" emits the Pool.use_threadlocal
        deprecation warning.
        """
        url = "postgresql://scott:tiger@somehost/test"

        def engine_with_flag(flag):
            settings = {
                "sqlalchemy.url": url,
                "sqlalchemy.pool_threadlocal": flag,
            }
            return engine_from_config(
                settings, module=mock_dbapi, _initialize=False
            )

        engine = engine_with_flag("false")
        eq_(engine.pool._use_threadlocal, False)

        with testing.expect_deprecated(
            "The Pool.use_threadlocal parameter is deprecated"
        ):
            engine = engine_with_flag("true")
        eq_(engine.pool._use_threadlocal, True)
106
107
class RecycleTest(fixtures.TestBase):
    __backend__ = True

    def test_basic(self):
        """A pool with a short recycle period survives a killed connection."""
        with testing.expect_deprecated(
            "The Pool.use_threadlocal parameter is deprecated"
        ):
            engine = engines.reconnecting_engine(
                options={"pool_threadlocal": True}
            )

        def select_one():
            # every contextual_connect() call emits a deprecation warning
            with testing.expect_deprecated(
                r"The Engine.contextual_connect\(\) method is deprecated"
            ):
                connection = engine.contextual_connect()
            eq_(connection.execute(select([1])).scalar(), 1)
            connection.close()

        select_one()

        # Lower the recycle period only now rather than at engine creation:
        # cx_oracle is slow to make the first connection, and doing it up
        # front would build a large delay into this test.
        engine.pool._recycle = 1

        # kill the DB connection, then wait until past the recycle period
        engine.test_shutdown()
        time.sleep(2)

        # connecting again must work without raising
        select_one()
147
148
class TLTransactionTest(fixtures.TestBase):
    """Transaction behavior of the deprecated 'threadlocal' engine strategy.

    Most tests use the module-level ``tlengine`` created in ``setup_class``.
    ``setup`` closes that engine before each test, and ``teardown`` deletes
    all rows from ``query_users`` after each test.
    """

    __requires__ = ("ad_hoc_engines",)
    __backend__ = True

    @classmethod
    def setup_class(cls):
        # users, metadata and tlengine are rebound as module globals so the
        # test methods below can refer to them directly
        global users, metadata, tlengine

        with _tlengine_deprecated():
            tlengine = testing_engine(options=dict(strategy="threadlocal"))
        metadata = MetaData()
        users = Table(
            "query_users",
            metadata,
            Column(
                "user_id",
                INT,
                Sequence("query_users_id_seq", optional=True),
                primary_key=True,
            ),
            Column("user_name", VARCHAR(20)),
            test_needs_acid=True,
        )
        metadata.create_all(tlengine)

    def teardown(self):
        # delete rows inserted by the test so the next test starts empty
        tlengine.execute(users.delete()).close()

    @classmethod
    def teardown_class(cls):
        tlengine.close()
        metadata.drop_all(tlengine)
        tlengine.dispose()

    def setup(self):

        # ensure tests start with engine closed

        tlengine.close()

    @testing.crashes(
        "oracle", "TNS error of unknown origin occurs on the buildbot."
    )
    def test_rollback_no_trans(self):
        """rollback() with no transaction in progress does not raise."""
        # a fresh local engine, shadowing the shared module-level one
        with _tlengine_deprecated():
            tlengine = testing_engine(options=dict(strategy="threadlocal"))

        # shouldn't fail
        tlengine.rollback()

        tlengine.begin()
        tlengine.rollback()

        # shouldn't fail
        tlengine.rollback()

    def test_commit_no_trans(self):
        """commit() with no transaction in progress does not raise."""
        # a fresh local engine, shadowing the shared module-level one
        with _tlengine_deprecated():
            tlengine = testing_engine(options=dict(strategy="threadlocal"))

        # shouldn't fail
        tlengine.commit()

        tlengine.begin()
        tlengine.rollback()

        # shouldn't fail
        tlengine.commit()

    def test_prepare_no_trans(self):
        """prepare() with no transaction in progress does not raise."""
        # a fresh local engine, shadowing the shared module-level one
        with _tlengine_deprecated():
            tlengine = testing_engine(options=dict(strategy="threadlocal"))

        # shouldn't fail
        tlengine.prepare()

        tlengine.begin()
        tlengine.rollback()

        # shouldn't fail
        tlengine.prepare()

    def test_connection_close(self):
        """test that when connections are closed for real, transactions
        are rolled back and disposed."""

        c = tlengine.contextual_connect()
        c.begin()
        assert c.in_transaction()
        c.close()
        assert not c.in_transaction()

    def test_transaction_close(self):
        """Closing the outer transaction discards rows from both levels."""
        c = tlengine.contextual_connect()
        t = c.begin()
        tlengine.execute(users.insert(), user_id=1, user_name="user1")
        tlengine.execute(users.insert(), user_id=2, user_name="user2")
        t2 = c.begin()
        tlengine.execute(users.insert(), user_id=3, user_name="user3")
        tlengine.execute(users.insert(), user_id=4, user_name="user4")
        t2.close()
        # closing the inner transaction leaves all 4 rows visible here
        result = c.execute("select * from query_users")
        assert len(result.fetchall()) == 4
        t.close()
        # a separate connection sees nothing once the outer one is closed
        external_connection = tlengine.connect()
        result = external_connection.execute("select * from query_users")
        try:
            assert len(result.fetchall()) == 0
        finally:
            c.close()
            external_connection.close()

    def test_rollback(self):
        """test a basic rollback"""

        tlengine.begin()
        tlengine.execute(users.insert(), user_id=1, user_name="user1")
        tlengine.execute(users.insert(), user_id=2, user_name="user2")
        tlengine.execute(users.insert(), user_id=3, user_name="user3")
        tlengine.rollback()
        external_connection = tlengine.connect()
        result = external_connection.execute("select * from query_users")
        try:
            assert len(result.fetchall()) == 0
        finally:
            external_connection.close()

    def test_commit(self):
        """test a basic commit"""

        tlengine.begin()
        tlengine.execute(users.insert(), user_id=1, user_name="user1")
        tlengine.execute(users.insert(), user_id=2, user_name="user2")
        tlengine.execute(users.insert(), user_id=3, user_name="user3")
        tlengine.commit()
        external_connection = tlengine.connect()
        result = external_connection.execute("select * from query_users")
        try:
            assert len(result.fetchall()) == 3
        finally:
            external_connection.close()

    def test_with_interface(self):
        """Transactions work through their context-manager __exit__ hooks."""
        trans = tlengine.begin()
        tlengine.execute(users.insert(), user_id=1, user_name="user1")
        tlengine.execute(users.insert(), user_id=2, user_name="user2")
        trans.commit()

        trans = tlengine.begin()
        tlengine.execute(users.insert(), user_id=3, user_name="user3")
        # simulate a with-block exiting on an exception: user3 is discarded
        trans.__exit__(Exception, "fake", None)
        trans = tlengine.begin()
        tlengine.execute(users.insert(), user_id=4, user_name="user4")
        # simulate a clean with-block exit: user4 is kept
        trans.__exit__(None, None, None)
        eq_(
            tlengine.execute(
                users.select().order_by(users.c.user_id)
            ).fetchall(),
            [(1, "user1"), (2, "user2"), (4, "user4")],
        )

    def test_commits(self):
        """Successive begin/commit cycles on one contextual connection."""
        connection = tlengine.connect()
        assert (
            connection.execute("select count(*) from query_users").scalar()
            == 0
        )
        connection.close()
        connection = tlengine.contextual_connect()
        transaction = connection.begin()
        connection.execute(users.insert(), user_id=1, user_name="user1")
        transaction.commit()
        transaction = connection.begin()
        connection.execute(users.insert(), user_id=2, user_name="user2")
        connection.execute(users.insert(), user_id=3, user_name="user3")
        transaction.commit()
        transaction = connection.begin()
        result = connection.execute("select * from query_users")
        rows = result.fetchall()
        assert len(rows) == 3, "expected 3 got %d" % len(rows)
        transaction.commit()
        connection.close()

    def test_rollback_off_conn(self):

        # test that a TLTransaction opened off a TLConnection allows
        # that TLConnection to be aware of the transactional context

        conn = tlengine.contextual_connect()
        trans = conn.begin()
        conn.execute(users.insert(), user_id=1, user_name="user1")
        conn.execute(users.insert(), user_id=2, user_name="user2")
        conn.execute(users.insert(), user_id=3, user_name="user3")
        trans.rollback()
        external_connection = tlengine.connect()
        result = external_connection.execute("select * from query_users")
        try:
            assert len(result.fetchall()) == 0
        finally:
            conn.close()
            external_connection.close()

    def test_morerollback_off_conn(self):

        # test that an existing TLConnection automatically takes place
        # in a TLTransaction opened on a second TLConnection

        conn = tlengine.contextual_connect()
        conn2 = tlengine.contextual_connect()
        trans = conn2.begin()
        conn.execute(users.insert(), user_id=1, user_name="user1")
        conn.execute(users.insert(), user_id=2, user_name="user2")
        conn.execute(users.insert(), user_id=3, user_name="user3")
        trans.rollback()
        external_connection = tlengine.connect()
        result = external_connection.execute("select * from query_users")
        try:
            assert len(result.fetchall()) == 0
        finally:
            conn.close()
            conn2.close()
            external_connection.close()

    def test_commit_off_connection(self):
        """Committing a transaction begun off a TLConnection persists rows."""
        conn = tlengine.contextual_connect()
        trans = conn.begin()
        conn.execute(users.insert(), user_id=1, user_name="user1")
        conn.execute(users.insert(), user_id=2, user_name="user2")
        conn.execute(users.insert(), user_id=3, user_name="user3")
        trans.commit()
        external_connection = tlengine.connect()
        result = external_connection.execute("select * from query_users")
        try:
            assert len(result.fetchall()) == 3
        finally:
            conn.close()
            external_connection.close()

    def test_nesting_rollback(self):
        """tests nesting of transactions, rollback at the end"""

        # connect() must not hand back the thread-local connection
        external_connection = tlengine.connect()
        self.assert_(
            external_connection.connection
            is not tlengine.contextual_connect().connection
        )
        tlengine.begin()
        tlengine.execute(users.insert(), user_id=1, user_name="user1")
        tlengine.execute(users.insert(), user_id=2, user_name="user2")
        tlengine.execute(users.insert(), user_id=3, user_name="user3")
        tlengine.begin()
        tlengine.execute(users.insert(), user_id=4, user_name="user4")
        tlengine.execute(users.insert(), user_id=5, user_name="user5")
        # inner commit does not persist; the outer rollback discards all
        tlengine.commit()
        tlengine.rollback()
        try:
            self.assert_(
                external_connection.scalar("select count(*) from query_users")
                == 0
            )
        finally:
            external_connection.close()

    def test_nesting_commit(self):
        """tests nesting of transactions, commit at the end."""

        # connect() must not hand back the thread-local connection
        external_connection = tlengine.connect()
        self.assert_(
            external_connection.connection
            is not tlengine.contextual_connect().connection
        )
        tlengine.begin()
        tlengine.execute(users.insert(), user_id=1, user_name="user1")
        tlengine.execute(users.insert(), user_id=2, user_name="user2")
        tlengine.execute(users.insert(), user_id=3, user_name="user3")
        tlengine.begin()
        tlengine.execute(users.insert(), user_id=4, user_name="user4")
        tlengine.execute(users.insert(), user_id=5, user_name="user5")
        tlengine.commit()
        tlengine.commit()
        try:
            self.assert_(
                external_connection.scalar("select count(*) from query_users")
                == 5
            )
        finally:
            external_connection.close()

    def test_mixed_nesting(self):
        """tests nesting of transactions off the TLEngine directly
        inside of transactions off the connection from the TLEngine"""

        external_connection = tlengine.connect()
        self.assert_(
            external_connection.connection
            is not tlengine.contextual_connect().connection
        )
        conn = tlengine.contextual_connect()
        trans = conn.begin()
        trans2 = conn.begin()
        tlengine.execute(users.insert(), user_id=1, user_name="user1")
        tlengine.execute(users.insert(), user_id=2, user_name="user2")
        tlengine.execute(users.insert(), user_id=3, user_name="user3")
        tlengine.begin()
        tlengine.execute(users.insert(), user_id=4, user_name="user4")
        tlengine.begin()
        tlengine.execute(users.insert(), user_id=5, user_name="user5")
        tlengine.execute(users.insert(), user_id=6, user_name="user6")
        tlengine.execute(users.insert(), user_id=7, user_name="user7")
        tlengine.commit()
        tlengine.execute(users.insert(), user_id=8, user_name="user8")
        tlengine.commit()
        trans2.commit()
        # rolling back the outermost transaction discards everything
        trans.rollback()
        conn.close()
        try:
            self.assert_(
                external_connection.scalar("select count(*) from query_users")
                == 0
            )
        finally:
            external_connection.close()

    def test_more_mixed_nesting(self):
        """tests nesting of transactions off the connection from the
        TLEngine inside of transactions off the TLEngine directly."""

        external_connection = tlengine.connect()
        self.assert_(
            external_connection.connection
            is not tlengine.contextual_connect().connection
        )
        tlengine.begin()
        connection = tlengine.contextual_connect()
        connection.execute(users.insert(), user_id=1, user_name="user1")
        tlengine.begin()
        connection.execute(users.insert(), user_id=2, user_name="user2")
        connection.execute(users.insert(), user_id=3, user_name="user3")
        trans = connection.begin()
        connection.execute(users.insert(), user_id=4, user_name="user4")
        connection.execute(users.insert(), user_id=5, user_name="user5")
        trans.commit()
        tlengine.commit()
        # the outermost engine-level rollback discards everything
        tlengine.rollback()
        connection.close()
        try:
            self.assert_(
                external_connection.scalar("select count(*) from query_users")
                == 0
            )
        finally:
            external_connection.close()

    @testing.requires.savepoints
    def test_nested_subtransaction_rollback(self):
        """Rolling back a begin_nested() savepoint keeps the outer rows."""
        tlengine.begin()
        tlengine.execute(users.insert(), user_id=1, user_name="user1")
        tlengine.begin_nested()
        tlengine.execute(users.insert(), user_id=2, user_name="user2")
        tlengine.rollback()
        tlengine.execute(users.insert(), user_id=3, user_name="user3")
        tlengine.commit()
        tlengine.close()
        eq_(
            tlengine.execute(
                select([users.c.user_id]).order_by(users.c.user_id)
            ).fetchall(),
            [(1,), (3,)],
        )
        tlengine.close()

    @testing.requires.savepoints
    @testing.crashes(
        "oracle+zxjdbc",
        "Errors out and causes subsequent tests to " "deadlock",
    )
    def test_nested_subtransaction_commit(self):
        """Committing a begin_nested() savepoint keeps its rows."""
        tlengine.begin()
        tlengine.execute(users.insert(), user_id=1, user_name="user1")
        tlengine.begin_nested()
        tlengine.execute(users.insert(), user_id=2, user_name="user2")
        tlengine.commit()
        tlengine.execute(users.insert(), user_id=3, user_name="user3")
        tlengine.commit()
        tlengine.close()
        eq_(
            tlengine.execute(
                select([users.c.user_id]).order_by(users.c.user_id)
            ).fetchall(),
            [(1,), (2,), (3,)],
        )
        tlengine.close()

    @testing.requires.savepoints
    def test_rollback_to_subtransaction(self):
        """A rollback inside a savepoint discards the whole savepoint."""
        tlengine.begin()
        tlengine.execute(users.insert(), user_id=1, user_name="user1")
        tlengine.begin_nested()
        tlengine.execute(users.insert(), user_id=2, user_name="user2")
        tlengine.begin()
        tlengine.execute(users.insert(), user_id=3, user_name="user3")
        tlengine.rollback()
        tlengine.rollback()
        tlengine.execute(users.insert(), user_id=4, user_name="user4")
        tlengine.commit()
        tlengine.close()
        eq_(
            tlengine.execute(
                select([users.c.user_id]).order_by(users.c.user_id)
            ).fetchall(),
            [(1,), (4,)],
        )
        tlengine.close()

    def test_connections(self):
        """tests that contextual_connect is threadlocal"""

        c1 = tlengine.contextual_connect()
        c2 = tlengine.contextual_connect()
        assert c1.connection is c2.connection
        c2.close()
        assert not c1.closed
        assert not tlengine.closed

    @testing.requires.independent_cursors
    def test_result_closing(self):
        """tests that results from tlengine.execute() share one connection,
        which is closed only once every open result has been closed"""

        r1 = tlengine.execute(select([1]))
        r2 = tlengine.execute(select([1]))
        r1.fetchone()
        r2.fetchone()
        r1.close()
        assert r2.connection is r1.connection
        assert not r2.connection.closed
        assert not tlengine.closed

        # close again, nothing happens since resultproxy calls close()
        # only once

        r1.close()
        assert r2.connection is r1.connection
        assert not r2.connection.closed
        assert not tlengine.closed
        r2.close()
        assert r2.connection.closed
        assert tlengine.closed

    @testing.crashes(
        "oracle+cx_oracle", "intermittent failures on the buildbot"
    )
    def test_dispose(self):
        """A threadlocal engine is still usable after dispose()."""
        with _tlengine_deprecated():
            eng = testing_engine(options=dict(strategy="threadlocal"))
        eng.execute(select([1]))
        eng.dispose()
        eng.execute(select([1]))

    @testing.requires.two_phase_transactions
    def test_two_phase_transaction(self):
        """Only two-phase transactions that were committed keep rows."""
        tlengine.begin_twophase()
        tlengine.execute(users.insert(), user_id=1, user_name="user1")
        tlengine.prepare()
        tlengine.commit()
        tlengine.begin_twophase()
        tlengine.execute(users.insert(), user_id=2, user_name="user2")
        tlengine.commit()
        tlengine.begin_twophase()
        tlengine.execute(users.insert(), user_id=3, user_name="user3")
        tlengine.rollback()
        tlengine.begin_twophase()
        tlengine.execute(users.insert(), user_id=4, user_name="user4")
        tlengine.prepare()
        tlengine.rollback()
        eq_(
            tlengine.execute(
                select([users.c.user_id]).order_by(users.c.user_id)
            ).fetchall(),
            [(1,), (2,)],
        )
629
630
class ConvenienceExecuteTest(fixtures.TablesTest):
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        cls.table = Table(
            "exec_test",
            metadata,
            Column("a", Integer),
            Column("b", Integer),
            test_needs_acid=True,
        )

    def _trans_fn(self, is_transaction=False):
        """Build a callable that inserts the row (a=x, b=value).

        With ``is_transaction`` the first argument is a transaction, and the
        insert runs on its ``.connection``.
        """

        def insert_row(conn, x, value=None):
            target = conn.connection if is_transaction else conn
            target.execute(self.table.insert().values(a=x, b=value))

        return insert_row

    def _trans_rollback_fn(self, is_transaction=False):
        """Like ``_trans_fn``, but raise SomeException after inserting."""
        insert_row = self._trans_fn(is_transaction)

        def insert_then_fail(conn, x, value=None):
            insert_row(conn, x, value=value)
            raise SomeException("breakage")

        return insert_then_fail

    def _assert_no_data(self):
        row_count = testing.db.scalar(
            select([func.count("*")]).select_from(self.table)
        )
        eq_(row_count, 0)

    def _assert_fn(self, x, value=None):
        rows = testing.db.execute(self.table.select()).fetchall()
        eq_(rows, [(x, value)])

    def _tlocal_engine(self):
        """A threadlocal engine sharing testing.db's pool."""
        with _tlengine_deprecated():
            return engines.testing_engine(
                options=dict(strategy="threadlocal", pool=testing.db.pool)
            )

    def test_transaction_tlocal_engine_ctx_commit(self):
        fn = self._trans_fn()
        ctx = self._tlocal_engine().begin()
        testing.run_as_contextmanager(ctx, fn, 5, value=8)
        self._assert_fn(5, value=8)

    def test_transaction_tlocal_engine_ctx_rollback(self):
        fn = self._trans_rollback_fn()
        ctx = self._tlocal_engine().begin()
        # the exception escapes the context manager and the insert is undone
        assert_raises_message(
            Exception,
            "breakage",
            testing.run_as_contextmanager,
            ctx,
            fn,
            5,
            value=8,
        )
        self._assert_no_data()
699
700
def _proxy_execute_deprecated():
    """Return a pair of context managers expecting proxy-execute warnings.

    The first expects the ``ConnectionProxy.execute`` deprecation and the
    second the ``ConnectionProxy.cursor_execute`` deprecation.
    """
    warning_texts = (
        "ConnectionProxy.execute is deprecated.",
        "ConnectionProxy.cursor_execute is deprecated.",
    )
    return tuple(testing.expect_deprecated(text) for text in warning_texts)
708
709
710class ProxyConnectionTest(fixtures.TestBase):
711
712    """These are the same tests as EngineEventsTest, except using
713    the deprecated ConnectionProxy interface.
714
715    """
716
717    __requires__ = ("ad_hoc_engines",)
718    __prefer_requires__ = ("two_phase_transactions",)
719
720    @testing.uses_deprecated(r".*Use event.listen")
721    @testing.fails_on("firebird", "Data type unknown")
722    def test_proxy(self):
723
724        stmts = []
725        cursor_stmts = []
726
727        class MyProxy(ConnectionProxy):
728            def execute(
729                self, conn, execute, clauseelement, *multiparams, **params
730            ):
731                stmts.append((str(clauseelement), params, multiparams))
732                return execute(clauseelement, *multiparams, **params)
733
734            def cursor_execute(
735                self,
736                execute,
737                cursor,
738                statement,
739                parameters,
740                context,
741                executemany,
742            ):
743                cursor_stmts.append((str(statement), parameters, None))
744                return execute(cursor, statement, parameters, context)
745
746        def assert_stmts(expected, received):
747            for stmt, params, posn in expected:
748                if not received:
749                    assert False, "Nothing available for stmt: %s" % stmt
750                while received:
751                    teststmt, testparams, testmultiparams = received.pop(0)
752                    teststmt = (
753                        re.compile(r"[\n\t ]+", re.M)
754                        .sub(" ", teststmt)
755                        .strip()
756                    )
757                    if teststmt.startswith(stmt) and (
758                        testparams == params or testparams == posn
759                    ):
760                        break
761
762        with testing.expect_deprecated(
763            "ConnectionProxy.execute is deprecated.",
764            "ConnectionProxy.cursor_execute is deprecated.",
765        ):
766            plain_engine = engines.testing_engine(
767                options=dict(implicit_returning=False, proxy=MyProxy())
768            )
769
770        with testing.expect_deprecated(
771            "ConnectionProxy.execute is deprecated.",
772            "ConnectionProxy.cursor_execute is deprecated.",
773            "The 'threadlocal' engine strategy is deprecated",
774        ):
775
776            tl_engine = engines.testing_engine(
777                options=dict(
778                    implicit_returning=False,
779                    proxy=MyProxy(),
780                    strategy="threadlocal",
781                )
782            )
783
784        for engine in (plain_engine, tl_engine):
785            m = MetaData(engine)
786            t1 = Table(
787                "t1",
788                m,
789                Column("c1", Integer, primary_key=True),
790                Column(
791                    "c2",
792                    String(50),
793                    default=func.lower("Foo"),
794                    primary_key=True,
795                ),
796            )
797            m.create_all()
798            try:
799                t1.insert().execute(c1=5, c2="some data")
800                t1.insert().execute(c1=6)
801                eq_(
802                    engine.execute("select * from t1").fetchall(),
803                    [(5, "some data"), (6, "foo")],
804                )
805            finally:
806                m.drop_all()
807            engine.dispose()
808            compiled = [
809                ("CREATE TABLE t1", {}, None),
810                (
811                    "INSERT INTO t1 (c1, c2)",
812                    {"c2": "some data", "c1": 5},
813                    None,
814                ),
815                ("INSERT INTO t1 (c1, c2)", {"c1": 6}, None),
816                ("select * from t1", {}, None),
817                ("DROP TABLE t1", {}, None),
818            ]
819
820            cursor = [
821                ("CREATE TABLE t1", {}, ()),
822                (
823                    "INSERT INTO t1 (c1, c2)",
824                    {"c2": "some data", "c1": 5},
825                    (5, "some data"),
826                ),
827                ("SELECT lower", {"lower_1": "Foo"}, ("Foo",)),
828                (
829                    "INSERT INTO t1 (c1, c2)",
830                    {"c2": "foo", "c1": 6},
831                    (6, "foo"),
832                ),
833                ("select * from t1", {}, ()),
834                ("DROP TABLE t1", {}, ()),
835            ]
836
837            assert_stmts(compiled, stmts)
838            assert_stmts(cursor, cursor_stmts)
839
840    @testing.uses_deprecated(r".*Use event.listen")
841    def test_options(self):
842        canary = []
843
844        class TrackProxy(ConnectionProxy):
845            def __getattribute__(self, key):
846                fn = object.__getattribute__(self, key)
847
848                def go(*arg, **kw):
849                    canary.append(fn.__name__)
850                    return fn(*arg, **kw)
851
852                return go
853
854        with testing.expect_deprecated(
855            *[
856                "ConnectionProxy.%s is deprecated" % name
857                for name in [
858                    "execute",
859                    "cursor_execute",
860                    "begin",
861                    "rollback",
862                    "commit",
863                    "savepoint",
864                    "rollback_savepoint",
865                    "release_savepoint",
866                    "begin_twophase",
867                    "prepare_twophase",
868                    "rollback_twophase",
869                    "commit_twophase",
870                ]
871            ]
872        ):
873            engine = engines.testing_engine(options={"proxy": TrackProxy()})
874        conn = engine.connect()
875        c2 = conn.execution_options(foo="bar")
876        eq_(c2._execution_options, {"foo": "bar"})
877        c2.execute(select([1]))
878        c3 = c2.execution_options(bar="bat")
879        eq_(c3._execution_options, {"foo": "bar", "bar": "bat"})
880        eq_(canary, ["execute", "cursor_execute"])
881
882    @testing.uses_deprecated(r".*Use event.listen")
883    def test_transactional(self):
884        canary = []
885
886        class TrackProxy(ConnectionProxy):
887            def __getattribute__(self, key):
888                fn = object.__getattribute__(self, key)
889
890                def go(*arg, **kw):
891                    canary.append(fn.__name__)
892                    return fn(*arg, **kw)
893
894                return go
895
896        with testing.expect_deprecated(
897            *[
898                "ConnectionProxy.%s is deprecated" % name
899                for name in [
900                    "execute",
901                    "cursor_execute",
902                    "begin",
903                    "rollback",
904                    "commit",
905                    "savepoint",
906                    "rollback_savepoint",
907                    "release_savepoint",
908                    "begin_twophase",
909                    "prepare_twophase",
910                    "rollback_twophase",
911                    "commit_twophase",
912                ]
913            ]
914        ):
915            engine = engines.testing_engine(options={"proxy": TrackProxy()})
916        conn = engine.connect()
917        trans = conn.begin()
918        conn.execute(select([1]))
919        trans.rollback()
920        trans = conn.begin()
921        conn.execute(select([1]))
922        trans.commit()
923
924        eq_(
925            canary,
926            [
927                "begin",
928                "execute",
929                "cursor_execute",
930                "rollback",
931                "begin",
932                "execute",
933                "cursor_execute",
934                "commit",
935            ],
936        )
937
    @testing.uses_deprecated(r".*Use event.listen")
    @testing.requires.savepoints
    @testing.requires.two_phase_transactions
    def test_transactional_advanced(self):
        """Savepoint and two-phase transaction operations are reported to
        the deprecated ConnectionProxy hooks in call order."""

        canary = []

        # every attribute lookup on the proxy returns a wrapper that
        # records the looked-up callable's __name__ before calling it
        class TrackProxy(ConnectionProxy):
            def __getattribute__(self, key):
                fn = object.__getattribute__(self, key)

                def go(*arg, **kw):
                    canary.append(fn.__name__)
                    return fn(*arg, **kw)

                return go

        # one deprecation warning is expected per ConnectionProxy hook
        with testing.expect_deprecated(
            *[
                "ConnectionProxy.%s is deprecated" % name
                for name in [
                    "execute",
                    "cursor_execute",
                    "begin",
                    "rollback",
                    "commit",
                    "savepoint",
                    "rollback_savepoint",
                    "release_savepoint",
                    "begin_twophase",
                    "prepare_twophase",
                    "rollback_twophase",
                    "commit_twophase",
                ]
            ]
        ):
            engine = engines.testing_engine(options={"proxy": TrackProxy()})
        conn = engine.connect()

        # outer transaction containing two nested (savepoint) transactions:
        # the first is rolled back, the second committed (released)
        trans = conn.begin()
        trans2 = conn.begin_nested()
        conn.execute(select([1]))
        trans2.rollback()
        trans2 = conn.begin_nested()
        conn.execute(select([1]))
        trans2.commit()
        trans.rollback()

        # two-phase transaction: prepare, then commit
        trans = conn.begin_twophase()
        conn.execute(select([1]))
        trans.prepare()
        trans.commit()

        # only the transactional hooks are compared; statement execution
        # hooks are filtered out
        canary = [t for t in canary if t not in ("cursor_execute", "execute")]
        eq_(
            canary,
            [
                "begin",
                "savepoint",
                "rollback_savepoint",
                "savepoint",
                "release_savepoint",
                "rollback",
                "begin_twophase",
                "prepare_twophase",
                "commit_twophase",
            ],
        )
1005
1006
class HandleInvalidatedOnConnectTest(fixtures.TestBase):
    """Connect-time failures from a mocked DBAPI, reached through the
    deprecated Engine.contextual_connect() method."""

    __requires__ = ("sqlite",)

    def setUp(self):
        real_sqlite = create_engine("sqlite://").dialect.dbapi

        fake_connection = Mock(
            get_server_version_info=Mock(return_value="5.0")
        )

        def hand_out_connection(*args, **kwargs):
            return fake_connection

        fake_dbapi = Mock(
            sqlite_version_info=(99, 9, 9),
            version_info=(99, 9, 9),
            sqlite_version="99.9.9",
            paramstyle="named",
            connect=Mock(side_effect=hand_out_connection),
        )
        # reuse the real sqlite3 exception classes on the fake module
        fake_dbapi.Error = (real_sqlite.Error,)
        fake_dbapi.ProgrammingError = real_sqlite.ProgrammingError

        self.dbapi = fake_dbapi
        self.ProgrammingError = real_sqlite.ProgrammingError

    def test_dont_touch_non_dbapi_exception_on_contextual_connect(self):
        """A non-DBAPI exception from connect() propagates unchanged and
        is never passed to dialect.is_disconnect()."""
        fake_dbapi = self.dbapi
        fake_dbapi.connect = Mock(
            side_effect=TypeError("I'm not a DBAPI error")
        )

        engine = create_engine("sqlite://", module=fake_dbapi)
        disconnect_check = Mock()
        engine.dialect.is_disconnect = disconnect_check
        with testing.expect_deprecated(
            r"The Engine.contextual_connect\(\) method is deprecated"
        ):
            assert_raises_message(
                TypeError, "I'm not a DBAPI error", engine.contextual_connect
            )
        eq_(disconnect_check.call_count, 0)

    def test_invalidate_on_contextual_connect(self):
        """A ProgrammingError raised by the DBAPI's connect() surfaces as a
        DBAPIError whose connection_invalidated flag is set.

        Interpretation of connection failures is not supported by every
        backend.
        """
        closed_db = self.ProgrammingError(
            "Cannot operate on a closed database."
        )
        self.dbapi.connect = Mock(side_effect=closed_db)
        engine = create_engine("sqlite://", module=self.dbapi)

        caught = None
        try:
            with testing.expect_deprecated(
                r"The Engine.contextual_connect\(\) method is deprecated"
            ):
                engine.contextual_connect()
        except tsa.exc.DBAPIError as err:
            caught = err
        assert caught is not None
        assert caught.connection_invalidated
1070
1071
class HandleErrorTest(fixtures.TestBase):
    """The deprecated ``dbapi_error`` connection event is invoked for a
    real DBAPI error, but not for errors raised by bind processing or by
    non-DBAPI exceptions from the cursor."""

    __requires__ = ("ad_hoc_engines",)
    __backend__ = True

    def tearDown(self):
        # reset Engine-level event dispatch state between tests
        Engine.dispatch._clear()
        Engine._has_events = False

    def test_legacy_dbapi_error(self):
        bad_sql = "SELECT FOO FROM I_DONT_EXIST"
        eng = engines.testing_engine()
        listener = Mock()

        with testing.expect_deprecated(
            r"The ConnectionEvents.dbapi_error\(\) event is deprecated"
        ):
            event.listen(eng, "dbapi_error", listener)

        with eng.connect() as conn:
            try:
                conn.execute(bad_sql)
                assert False
            except tsa.exc.DBAPIError as err:
                # positional args of the first listener call
                first_call_args = listener.mock_calls[0][1]
                eq_(first_call_args[5], err.orig)
                eq_(first_call_args[2], bad_sql)

    def test_legacy_dbapi_error_no_ad_hoc_context(self):
        eng = engines.testing_engine()

        canary = Mock(return_value=None)
        with testing.expect_deprecated(
            r"The ConnectionEvents.dbapi_error\(\) event is deprecated"
        ):
            event.listen(eng, "dbapi_error", canary)

        bind_failure = SomeException("nope")

        class MyType(TypeDecorator):
            impl = Integer

            def process_bind_param(self, value, dialect):
                raise bind_failure

        stmt = select([1]).where(column("foo") == literal("bar", MyType()))
        with eng.connect() as conn:
            assert_raises_message(
                tsa.exc.StatementError,
                r"\(.*SomeException\) nope\n\[SQL\: u?SELECT 1 ",
                conn.execute,
                stmt,
            )
        # the failure came from a bind processor: no legacy event
        eq_(canary.mock_calls, [])

    def test_legacy_dbapi_error_non_dbapi_error(self):
        eng = engines.testing_engine()

        canary = Mock(return_value=None)
        with testing.expect_deprecated(
            r"The ConnectionEvents.dbapi_error\(\) event is deprecated"
        ):
            event.listen(eng, "dbapi_error", canary)

        not_dbapi = TypeError("I'm not a DBAPI error")
        failing_cursor = Mock(execute=Mock(side_effect=not_dbapi))
        with eng.connect() as conn:
            conn.connection.cursor = Mock(return_value=failing_cursor)
            assert_raises_message(
                TypeError, "I'm not a DBAPI error", conn.execute, "select "
            )
        # a non-DBAPI exception does not trigger the legacy event
        eq_(canary.mock_calls, [])
1144
1145
def MockDBAPI():  # noqa
    """Return a ``Mock`` imitating a DBAPI module for the pool tests.

    ``connect()`` hands out a fresh mock connection on every call.  Each
    connection's ``cursor()`` returns a new ``Mock`` per call, and its
    ``close()`` sets the connection's ``closed`` attribute to True.

    ``shutdown(True)`` replaces ``connect`` with a mock that raises
    ``Exception("connect failed")``.  ``shutdown(False)`` restores a
    working ``connect``.  In both cases ``is_shutdown`` records the
    value passed.
    """

    def make_cursor():
        return Mock()

    def make_connection(*arg, **kw):
        def mark_closed():
            dbapi_conn.closed = True

        # mock seems like it might have an issue logging
        # call_count correctly under threading, not sure.
        # adding a side_effect for close seems to help.
        dbapi_conn = Mock(
            cursor=Mock(side_effect=make_cursor),
            close=Mock(side_effect=mark_closed),
            closed=False,
        )
        return dbapi_conn

    def shutdown(value):
        if not value:
            module.connect = Mock(side_effect=make_connection)
        else:
            module.connect = Mock(side_effect=Exception("connect failed"))
        module.is_shutdown = value

    module = Mock(
        connect=Mock(side_effect=make_connection),
        shutdown=shutdown,
        is_shutdown=False,
    )
    return module
1175
1176
class PoolTestBase(fixtures.TestBase):
    """Shared setup for the pool tests: resets the global pool managers
    and builds QueuePool instances backed by MockDBAPI()."""

    def setup(self):
        pool.clear_managers()
        self._teardown_conns = []

    def teardown(self):
        # each entry is called to obtain a connection, or a falsy value
        # if it is gone (presumably weak references); live ones are closed
        for conn_ref in self._teardown_conns:
            live_conn = conn_ref()
            if live_conn:
                live_conn.close()

    @classmethod
    def teardown_class(cls):
        pool.clear_managers()

    def _queuepool_fixture(self, **kw):
        # only the pool is needed here; the mock DBAPI is discarded
        return self._queuepool_dbapi_fixture(**kw)[1]

    def _queuepool_dbapi_fixture(self, **kw):
        """Return ``(dbapi, QueuePool)`` where the pool's creator calls
        ``dbapi.connect("foo.db")``; ``kw`` goes to QueuePool."""
        dbapi = MockDBAPI()

        def creator():
            return dbapi.connect("foo.db")

        queue_pool = pool.QueuePool(creator=creator, **kw)
        return dbapi, queue_pool
1202
1203
class DeprecatedPoolListenerTest(PoolTestBase):
    """Deprecated PoolListener objects/dicts adapted into pool events."""

    @testing.requires.predictable_gc
    @testing.uses_deprecated(
        r".*Use the PoolEvents", r".*'listeners' argument .* is deprecated"
    )
    def test_listeners(self):
        """PoolListener subclasses register only the hooks they define.
        Those hooks (connect, first_connect, checkout, checkin) fire the
        expected number of times across checkout/checkin, invalidation,
        detach, dispose() and recreate()."""

        class InstrumentingListener(object):
            def __init__(self):
                # swap each hook the subclass defines for a recording one
                if hasattr(self, "connect"):
                    self.connect = self.inst_connect
                if hasattr(self, "first_connect"):
                    self.first_connect = self.inst_first_connect
                if hasattr(self, "checkout"):
                    self.checkout = self.inst_checkout
                if hasattr(self, "checkin"):
                    self.checkin = self.inst_checkin
                self.clear()

            def clear(self):
                self.connected = []
                self.first_connected = []
                self.checked_out = []
                self.checked_in = []

            def assert_total(self, conn, fconn, cout, cin):
                eq_(len(self.connected), conn)
                eq_(len(self.first_connected), fconn)
                eq_(len(self.checked_out), cout)
                eq_(len(self.checked_in), cin)

            def assert_in(self, item, in_conn, in_fconn, in_cout, in_cin):
                eq_((item in self.connected), in_conn)
                eq_((item in self.first_connected), in_fconn)
                eq_((item in self.checked_out), in_cout)
                eq_((item in self.checked_in), in_cin)

            def inst_connect(self, con, record):
                print("connect(%s, %s)" % (con, record))
                assert con is not None
                assert record is not None
                self.connected.append(con)

            def inst_first_connect(self, con, record):
                print("first_connect(%s, %s)" % (con, record))
                assert con is not None
                assert record is not None
                self.first_connected.append(con)

            def inst_checkout(self, con, record, proxy):
                print("checkout(%s, %s, %s)" % (con, record, proxy))
                assert con is not None
                assert record is not None
                assert proxy is not None
                self.checked_out.append(con)

            def inst_checkin(self, con, record):
                print("checkin(%s, %s)" % (con, record))
                # con can be None if invalidated
                assert record is not None
                self.checked_in.append(con)

        class ListenAll(tsa.interfaces.PoolListener, InstrumentingListener):
            pass

        class ListenConnect(InstrumentingListener):
            def connect(self, con, record):
                pass

        class ListenFirstConnect(InstrumentingListener):
            def first_connect(self, con, record):
                pass

        class ListenCheckOut(InstrumentingListener):
            def checkout(self, con, record, proxy, num):
                pass

        class ListenCheckIn(InstrumentingListener):
            def checkin(self, con, record):
                pass

        def assert_listeners(p, total, conn, fconn, cout, cin):
            # ``total`` is not checked; per-event listener counts are
            # verified on the pool and on a recreate()d copy of it
            for instance in (p, p.recreate()):
                self.assert_(len(instance.dispatch.connect) == conn)
                self.assert_(len(instance.dispatch.first_connect) == fconn)
                self.assert_(len(instance.dispatch.checkout) == cout)
                self.assert_(len(instance.dispatch.checkin) == cin)

        p = self._queuepool_fixture()
        assert_listeners(p, 0, 0, 0, 0, 0)

        with testing.expect_deprecated(
            *[
                "PoolListener.%s is deprecated." % name
                for name in ["connect", "first_connect", "checkout", "checkin"]
            ]
        ):
            p.add_listener(ListenAll())
        assert_listeners(p, 1, 1, 1, 1, 1)

        with testing.expect_deprecated(
            *["PoolListener.%s is deprecated." % name for name in ["connect"]]
        ):
            p.add_listener(ListenConnect())
        assert_listeners(p, 2, 2, 1, 1, 1)

        with testing.expect_deprecated(
            *[
                "PoolListener.%s is deprecated." % name
                for name in ["first_connect"]
            ]
        ):
            p.add_listener(ListenFirstConnect())
        assert_listeners(p, 3, 2, 2, 1, 1)

        with testing.expect_deprecated(
            *["PoolListener.%s is deprecated." % name for name in ["checkout"]]
        ):
            p.add_listener(ListenCheckOut())
        assert_listeners(p, 4, 2, 2, 2, 1)

        with testing.expect_deprecated(
            *["PoolListener.%s is deprecated." % name for name in ["checkin"]]
        ):
            p.add_listener(ListenCheckIn())
        assert_listeners(p, 5, 2, 2, 2, 2)
        del p

        snoop = ListenAll()

        with testing.expect_deprecated(
            *[
                "PoolListener.%s is deprecated." % name
                for name in ["connect", "first_connect", "checkout", "checkin"]
            ]
            + [
                "PoolListener is deprecated in favor of the PoolEvents "
                "listener interface.  The Pool.listeners parameter "
                "will be removed"
            ]
        ):
            p = self._queuepool_fixture(listeners=[snoop])
        assert_listeners(p, 1, 1, 1, 1, 1)

        c = p.connect()
        snoop.assert_total(1, 1, 1, 0)
        cc = c.connection
        snoop.assert_in(cc, True, True, True, False)
        c.close()
        snoop.assert_in(cc, True, True, True, True)
        del c, cc

        snoop.clear()

        # this one depends on immediate gc
        c = p.connect()
        cc = c.connection
        snoop.assert_in(cc, False, False, True, False)
        snoop.assert_total(0, 0, 1, 0)
        del c, cc
        lazy_gc()
        snoop.assert_total(0, 0, 1, 1)

        # after dispose(), connect fires again but first_connect does not
        p.dispose()
        snoop.clear()

        c = p.connect()
        c.close()
        c = p.connect()
        snoop.assert_total(1, 0, 2, 1)
        c.close()
        snoop.assert_total(1, 0, 2, 2)

        # invalidation: invalidate() checks the connection in once and
        # the later close() does not check it in again
        p.dispose()
        snoop.clear()

        c = p.connect()
        snoop.assert_total(1, 0, 1, 0)
        c.invalidate()
        snoop.assert_total(1, 0, 1, 1)
        c.close()
        snoop.assert_total(1, 0, 1, 1)
        del c
        lazy_gc()
        snoop.assert_total(1, 0, 1, 1)
        c = p.connect()
        snoop.assert_total(2, 0, 2, 1)
        c.close()
        del c
        lazy_gc()
        snoop.assert_total(2, 0, 2, 2)

        # detached: a detached connection is not checked back in
        p.dispose()
        snoop.clear()

        c = p.connect()
        snoop.assert_total(1, 0, 1, 0)
        c.detach()
        snoop.assert_total(1, 0, 1, 0)
        c.close()
        del c
        snoop.assert_total(1, 0, 1, 0)
        c = p.connect()
        snoop.assert_total(2, 0, 2, 0)
        c.close()
        del c
        snoop.assert_total(2, 0, 2, 1)

        # recreated: first_connect fires again on the new pool
        p = p.recreate()
        snoop.clear()

        c = p.connect()
        snoop.assert_total(1, 1, 1, 0)
        c.close()
        snoop.assert_total(1, 1, 1, 1)
        c = p.connect()
        snoop.assert_total(1, 1, 2, 1)
        c.close()
        snoop.assert_total(1, 1, 2, 2)

    @testing.uses_deprecated(r".*Use the PoolEvents")
    def test_listeners_callables(self):
        """Dict-based listeners are adapted into pool events for each pool
        class under test.  Adding a callable that is already registered
        for an event does not register it a second time."""

        def connect(dbapi_con, con_record):
            counts[0] += 1

        def checkout(dbapi_con, con_record, con_proxy):
            counts[1] += 1

        def checkin(dbapi_con, con_record):
            counts[2] += 1

        i_all = dict(connect=connect, checkout=checkout, checkin=checkin)
        i_connect = dict(connect=connect)
        i_checkout = dict(checkout=checkout)
        i_checkin = dict(checkin=checkin)

        def assert_listeners(p, total, conn, cout, cin):
            # ``total`` is not checked; listener counts are verified on
            # the pool and on a recreate()d copy of it
            for instance in (p, p.recreate()):
                eq_(len(instance.dispatch.connect), conn)
                eq_(len(instance.dispatch.checkout), cout)
                eq_(len(instance.dispatch.checkin), cin)

        for cls in (pool.QueuePool, pool.StaticPool):
            counts = [0, 0, 0]
            dbapi = MockDBAPI()

            def make_pool(**kw):
                # build the pool class under test.  Previously the loop
                # always used _queuepool_fixture(), so StaticPool was
                # never exercised.
                return cls(creator=lambda: dbapi.connect("foo.db"), **kw)

            p = make_pool()
            assert_listeners(p, 0, 0, 0, 0)

            with testing.expect_deprecated(
                *[
                    "PoolListener.%s is deprecated." % name
                    for name in ["connect", "checkout", "checkin"]
                ]
            ):
                p.add_listener(i_all)
            assert_listeners(p, 1, 1, 1, 1)

            with testing.expect_deprecated(
                *[
                    "PoolListener.%s is deprecated." % name
                    for name in ["connect"]
                ]
            ):
                p.add_listener(i_connect)
            assert_listeners(p, 2, 1, 1, 1)

            with testing.expect_deprecated(
                *[
                    "PoolListener.%s is deprecated." % name
                    for name in ["checkout"]
                ]
            ):
                p.add_listener(i_checkout)
            assert_listeners(p, 3, 1, 1, 1)

            with testing.expect_deprecated(
                *[
                    "PoolListener.%s is deprecated." % name
                    for name in ["checkin"]
                ]
            ):
                p.add_listener(i_checkin)
            assert_listeners(p, 4, 1, 1, 1)
            del p

            with testing.expect_deprecated(
                *[
                    "PoolListener.%s is deprecated." % name
                    for name in ["connect", "checkout", "checkin"]
                ]
                + [".*The Pool.listeners parameter will be removed"]
            ):
                p = make_pool(listeners=[i_all])
            assert_listeners(p, 1, 1, 1, 1)

            c = p.connect()
            assert counts == [1, 1, 0]
            c.close()
            assert counts == [1, 1, 1]

            c = p.connect()
            assert counts == [1, 2, 1]
            # checkin is already registered via i_all; re-adding it must
            # not make it fire twice
            with testing.expect_deprecated(
                *[
                    "PoolListener.%s is deprecated." % name
                    for name in ["checkin"]
                ]
            ):
                p.add_listener(i_checkin)
            c.close()
            assert counts == [1, 2, 2]
1517
1518
class PoolTest(PoolTestBase):
    """Deprecated pool.manage() and Pool.use_threadlocal behaviors."""

    def test_manager(self):
        # with use_threadlocal=True, repeating the same connect() arguments
        # returns the same connection object (c1/c2, c4/c6); a different
        # database or different keyword arguments yield a different one
        with testing.expect_deprecated(
            r"The pool.manage\(\) function is deprecated,"
        ):
            manager = pool.manage(MockDBAPI(), use_threadlocal=True)

        with testing.expect_deprecated(
            r".*Pool.use_threadlocal parameter is deprecated"
        ):
            c1 = manager.connect("foo.db")
            c2 = manager.connect("foo.db")
            c3 = manager.connect("bar.db")
            c4 = manager.connect("foo.db", bar="bat")
            c5 = manager.connect("foo.db", bar="hoho")
            c6 = manager.connect("foo.db", bar="bat")

        assert c1.cursor() is not None
        assert c1 is c2
        assert c1 is not c3
        assert c4 is c6
        assert c4 is not c5

    def test_manager_with_key(self):
        # sa_pool_key selects the pool: "bar.db" with key "a" reuses the
        # key-"a" connection (c1 is c3), so the DBAPI is only ever
        # connected with "foo.db"

        dbapi = MockDBAPI()

        with testing.expect_deprecated(
            r"The pool.manage\(\) function is deprecated,"
        ):
            manager = pool.manage(dbapi, use_threadlocal=True)

        with testing.expect_deprecated(
            r".*Pool.use_threadlocal parameter is deprecated"
        ):
            c1 = manager.connect("foo.db", sa_pool_key="a")
            c2 = manager.connect("foo.db", sa_pool_key="b")
            c3 = manager.connect("bar.db", sa_pool_key="a")

        assert c1.cursor() is not None
        assert c1 is not c2
        assert c1 is c3

        eq_(dbapi.connect.mock_calls, [call("foo.db"), call("foo.db")])

    def test_bad_args(self):
        # connect(None) through a manager is expected to succeed
        with testing.expect_deprecated(
            r"The pool.manage\(\) function is deprecated,"
        ):
            manager = pool.manage(MockDBAPI())
        manager.connect(None)

    def test_non_thread_local_manager(self):
        # without use_threadlocal, each connect() returns a distinct
        # connection even for identical arguments
        with testing.expect_deprecated(
            r"The pool.manage\(\) function is deprecated,"
        ):
            manager = pool.manage(MockDBAPI(), use_threadlocal=False)

        connection = manager.connect("foo.db")
        connection2 = manager.connect("foo.db")

        self.assert_(connection.cursor() is not None)
        self.assert_(connection is not connection2)

    def test_threadlocal_del(self):
        self._do_testthreadlocal(useclose=False)

    def test_threadlocal_close(self):
        self._do_testthreadlocal(useclose=True)

    def _do_testthreadlocal(self, useclose=False):
        """Exercise use_threadlocal=True on QueuePool and
        SingletonThreadPool.

        :param useclose: if True, connections are released with close();
            otherwise they are released by dropping the reference (with
            lazy_gc() where the test depends on collection).
        """
        dbapi = MockDBAPI()

        with testing.expect_deprecated(
            r".*Pool.use_threadlocal parameter is deprecated"
        ):
            for p in (
                pool.QueuePool(
                    creator=dbapi.connect,
                    pool_size=3,
                    max_overflow=-1,
                    use_threadlocal=True,
                ),
                pool.SingletonThreadPool(
                    creator=dbapi.connect, use_threadlocal=True
                ),
            ):
                # repeated connect() in one thread shares a connection;
                # unique_connection() bypasses that sharing
                c1 = p.connect()
                c2 = p.connect()
                self.assert_(c1 is c2)
                c3 = p.unique_connection()
                self.assert_(c3 is not c1)
                if useclose:
                    c2.close()
                else:
                    c2 = None
                # c1 still holds the shared connection, so connect()
                # returns it again
                c2 = p.connect()
                self.assert_(c1 is c2)
                self.assert_(c3 is not c1)
                if useclose:
                    c2.close()
                else:
                    c2 = None
                    lazy_gc()
                if useclose:
                    c1 = p.connect()
                    c2 = p.connect()
                    c3 = p.connect()
                    c3.close()
                    c2.close()
                    self.assert_(c1.connection is not None)
                    c1.close()
                c1 = c2 = c3 = None

                # extra tests with QueuePool to ensure connections get
                # __del__()ed when dereferenced

                if isinstance(p, pool.QueuePool):
                    lazy_gc()
                    self.assert_(p.checkedout() == 0)
                    c1 = p.connect()
                    c2 = p.connect()
                    if useclose:
                        c2.close()
                        c1.close()
                    else:
                        c2 = None
                        c1 = None
                        lazy_gc()
                    self.assert_(p.checkedout() == 0)

    def test_mixed_close(self):
        # on a threadlocal pool, closing one of two references to the
        # shared connection leaves it checked out; it is returned only
        # once the last reference is gone, after which pool._refs is empty
        pool._refs.clear()
        with testing.expect_deprecated(
            r".*Pool.use_threadlocal parameter is deprecated"
        ):
            p = self._queuepool_fixture(
                pool_size=3, max_overflow=-1, use_threadlocal=True
            )
        c1 = p.connect()
        c2 = p.connect()
        assert c1 is c2
        c1.close()
        c2 = None
        assert p.checkedout() == 1
        c1 = None
        lazy_gc()
        assert p.checkedout() == 0
        lazy_gc()
        assert not pool._refs
1669
1670
class QueuePoolTest(PoolTestBase):
    """QueuePool behaviors under the deprecated use_threadlocal flag."""

    def test_threadfairy(self):
        # a connection obtained after close() still has a DBAPI connection
        with testing.expect_deprecated(
            r".*Pool.use_threadlocal parameter is deprecated"
        ):
            p = self._queuepool_fixture(
                pool_size=3, max_overflow=-1, use_threadlocal=True
            )
        c1 = p.connect()
        c1.close()
        c2 = p.connect()
        assert c2.connection is not None

    def test_trick_the_counter(self):
        """this is a "flaw" in the connection pool; since threadlocal
        uses a single ConnectionFairy per thread with an open/close
        counter, you can fool the counter into giving you a
        ConnectionFairy with an ambiguous counter.  i.e. it's not true
        reference counting."""

        with testing.expect_deprecated(
            r".*Pool.use_threadlocal parameter is deprecated"
        ):
            p = self._queuepool_fixture(
                pool_size=3, max_overflow=-1, use_threadlocal=True
            )
        c1 = p.connect()
        c2 = p.connect()
        assert c1 is c2
        c1.close()
        c2 = p.connect()
        # c2 is the same shared fairy: one close() is not enough to
        # return it to the pool, a second one is
        c2.close()
        self.assert_(p.checkedout() != 0)
        c2.close()
        self.assert_(p.checkedout() == 0)

    @testing.requires.predictable_gc
    def test_weakref_kaboom(self):
        # mixing close() with dereferencing, then collecting garbage,
        # leaves nothing checked out and the pool still usable
        with testing.expect_deprecated(
            r".*Pool.use_threadlocal parameter is deprecated"
        ):
            p = self._queuepool_fixture(
                pool_size=3, max_overflow=-1, use_threadlocal=True
            )
        c1 = p.connect()
        c2 = p.connect()
        c1.close()
        c2 = None
        del c1
        del c2
        gc_collect()
        assert p.checkedout() == 0
        c3 = p.connect()
        assert c3 is not None
1725
1726
class ExplicitAutoCommitDeprecatedTest(fixtures.TestBase):

    """Test the deprecated 'autocommit' flag on select() and text() objects.

    Requires PostgreSQL so that we may define a custom function which
    modifies the database."""

    __only_on__ = "postgresql"

    @classmethod
    def setup_class(cls):
        # "metadata" and "foo" are module globals so that the teardown
        # hooks and the tests below can reach them
        global metadata, foo
        metadata = MetaData(testing.db)
        foo = Table(
            "foo",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", String(100)),
        )
        metadata.create_all()
        # insert_foo() inserts a row into "foo" as a side effect of being
        # SELECTed; the tests check that an autocommit execution of it is
        # visible from a second connection
        testing.db.execute(
            "create function insert_foo(varchar) "
            "returns integer as 'insert into foo(data) "
            "values ($1);select 1;' language sql"
        )

    def teardown(self):
        # empty "foo" after each test
        foo.delete().execute().close()

    @classmethod
    def teardown_class(cls):
        testing.db.execute("drop function insert_foo(varchar)")
        metadata.drop_all()

    def test_explicit_compiled(self):
        # the connections are context managers so they are closed even
        # when an assertion below fails; previously a failure skipped the
        # explicit close() calls and leaked both connections
        with testing.db.connect() as conn1, testing.db.connect() as conn2:
            with testing.expect_deprecated(
                "The select.autocommit parameter is deprecated"
            ):
                conn1.execute(
                    select([func.insert_foo("data1")], autocommit=True)
                )
            assert conn2.execute(select([foo.c.data])).fetchall() == [
                ("data1",)
            ]
            with testing.expect_deprecated(
                r"The SelectBase.autocommit\(\) method is deprecated,"
            ):
                conn1.execute(select([func.insert_foo("data2")]).autocommit())
            assert conn2.execute(select([foo.c.data])).fetchall() == [
                ("data1",),
                ("data2",),
            ]

    def test_explicit_text(self):
        # context managers guarantee both connections are closed even if
        # the assertion fails
        with testing.db.connect() as conn1, testing.db.connect() as conn2:
            with testing.expect_deprecated(
                "The text.autocommit parameter is deprecated"
            ):
                conn1.execute(
                    text("select insert_foo('moredata')", autocommit=True)
                )
            assert conn2.execute(select([foo.c.data])).fetchall() == [
                ("moredata",)
            ]
1794