1# coding: utf-8
2
3from contextlib import contextmanager
4import re
5import threading
6import weakref
7
8import sqlalchemy as tsa
9from sqlalchemy import bindparam
10from sqlalchemy import create_engine
11from sqlalchemy import create_mock_engine
12from sqlalchemy import event
13from sqlalchemy import func
14from sqlalchemy import inspect
15from sqlalchemy import INT
16from sqlalchemy import Integer
17from sqlalchemy import LargeBinary
18from sqlalchemy import MetaData
19from sqlalchemy import select
20from sqlalchemy import Sequence
21from sqlalchemy import String
22from sqlalchemy import testing
23from sqlalchemy import text
24from sqlalchemy import TypeDecorator
25from sqlalchemy import util
26from sqlalchemy import VARCHAR
27from sqlalchemy.engine import default
28from sqlalchemy.engine.base import Connection
29from sqlalchemy.engine.base import Engine
30from sqlalchemy.pool import NullPool
31from sqlalchemy.pool import QueuePool
32from sqlalchemy.sql import column
33from sqlalchemy.sql import literal
34from sqlalchemy.sql.elements import literal_column
35from sqlalchemy.testing import assert_raises
36from sqlalchemy.testing import assert_raises_message
37from sqlalchemy.testing import config
38from sqlalchemy.testing import engines
39from sqlalchemy.testing import eq_
40from sqlalchemy.testing import expect_raises_message
41from sqlalchemy.testing import expect_warnings
42from sqlalchemy.testing import fixtures
43from sqlalchemy.testing import is_
44from sqlalchemy.testing import is_false
45from sqlalchemy.testing import is_not
46from sqlalchemy.testing import is_true
47from sqlalchemy.testing import mock
48from sqlalchemy.testing.assertions import expect_deprecated
49from sqlalchemy.testing.assertsql import CompiledSQL
50from sqlalchemy.testing.mock import call
51from sqlalchemy.testing.mock import Mock
52from sqlalchemy.testing.mock import patch
53from sqlalchemy.testing.schema import Column
54from sqlalchemy.testing.schema import Table
55from sqlalchemy.testing.util import gc_collect
56from sqlalchemy.testing.util import picklers
57from sqlalchemy.util import collections_abc
58
59
class SomeException(Exception):
    """Arbitrary non-DBAPI exception that tests in this module raise on
    purpose, e.g. from a bind processor or from a transactional callable."""
62
63
class Foo(object):
    """Non-string object used as an exception message argument.

    Its ``str()`` rendering is plain ASCII while its ``__unicode__``
    rendering contains non-ASCII characters. This lets tests check how
    SQLAlchemyError stringifies arbitrary objects.
    """

    def __str__(self):
        ascii_text = "foo"
        return ascii_text

    def __unicode__(self):
        unicode_text = util.u("fóó")
        return unicode_text
70
71
class ExecuteTest(fixtures.TablesTest):
    """Exercise statement execution through Connection and Engine.

    Covers:

    * raw driver SQL via ``exec_driver_sql`` in each DBAPI paramstyle;
    * validation of parameter sets;
    * exception wrapping, stringification and pickling;
    * engine- and connection-level ``execution_options``.
    """

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        # "users" takes explicit primary keys; "users_autoinc" generates
        # them, which test_empty_insert depends on.
        Table(
            "users",
            metadata,
            Column("user_id", INT, primary_key=True, autoincrement=False),
            Column("user_name", VARCHAR(20)),
        )
        Table(
            "users_autoinc",
            metadata,
            Column(
                "user_id", INT, primary_key=True, test_needs_autoincrement=True
            ),
            Column("user_name", VARCHAR(20)),
        )

    def test_no_params_option(self):
        """With no_parameters=True, a bare '%' in raw SQL round-trips as
        the literal string '%'."""
        stmt = (
            "SELECT '%'"
            + testing.db.dialect.statement_compiler(
                testing.db.dialect, None
            ).default_from()
        )

        with testing.db.connect() as conn:
            result = (
                conn.execution_options(no_parameters=True)
                .exec_driver_sql(stmt)
                .scalar()
            )
            eq_(result, "%")

    def test_raw_positional_invalid(self, connection):
        """A positional parameter list must hold tuples or dicts; a flat
        list of scalars or a list of lists is rejected."""
        assert_raises_message(
            tsa.exc.ArgumentError,
            "List argument must consist only of tuples or dictionaries",
            connection.exec_driver_sql,
            "insert into users (user_id, user_name) " "values (?, ?)",
            [2, "fred"],
        )

        assert_raises_message(
            tsa.exc.ArgumentError,
            "List argument must consist only of tuples or dictionaries",
            connection.exec_driver_sql,
            "insert into users (user_id, user_name) " "values (?, ?)",
            [[3, "ed"], [4, "horse"]],
        )

    def test_raw_named_invalid(self, connection):
        # this is awkward b.c. this is just testing if regular Python
        # is raising TypeError if they happened to send arguments that
        # look like the legacy ones which also happen to conflict with
        # the positional signature for the method.   some combinations
        # can get through and fail differently
        assert_raises(
            TypeError,
            connection.exec_driver_sql,
            "insert into users (user_id, user_name) "
            "values (%(id)s, %(name)s)",
            {"id": 2, "name": "ed"},
            {"id": 3, "name": "horse"},
            {"id": 4, "name": "horse"},
        )
        assert_raises(
            TypeError,
            connection.exec_driver_sql,
            "insert into users (user_id, user_name) "
            "values (%(id)s, %(name)s)",
            id=4,
            name="sally",
        )

    @testing.requires.qmark_paramstyle
    def test_raw_qmark(self, connection):
        """Single- and multi-row raw inserts and a select, qmark style."""
        conn = connection
        conn.exec_driver_sql(
            "insert into users (user_id, user_name) " "values (?, ?)",
            (1, "jack"),
        )
        conn.exec_driver_sql(
            "insert into users (user_id, user_name) " "values (?, ?)",
            (2, "fred"),
        )
        conn.exec_driver_sql(
            "insert into users (user_id, user_name) " "values (?, ?)",
            [(3, "ed"), (4, "horse")],
        )
        conn.exec_driver_sql(
            "insert into users (user_id, user_name) " "values (?, ?)",
            [(5, "barney"), (6, "donkey")],
        )
        conn.exec_driver_sql(
            "insert into users (user_id, user_name) " "values (?, ?)",
            (7, "sally"),
        )
        res = conn.exec_driver_sql("select * from users order by user_id")
        assert res.fetchall() == [
            (1, "jack"),
            (2, "fred"),
            (3, "ed"),
            (4, "horse"),
            (5, "barney"),
            (6, "donkey"),
            (7, "sally"),
        ]

        res = conn.exec_driver_sql(
            "select * from users where user_name=?", ("jack",)
        )
        assert res.fetchall() == [(1, "jack")]

    @testing.requires.format_paramstyle
    def test_raw_sprintf(self, connection):
        """Single- and multi-row raw inserts and a select, format style."""
        conn = connection
        conn.exec_driver_sql(
            "insert into users (user_id, user_name) " "values (%s, %s)",
            (1, "jack"),
        )
        conn.exec_driver_sql(
            "insert into users (user_id, user_name) " "values (%s, %s)",
            [(2, "ed"), (3, "horse")],
        )
        conn.exec_driver_sql(
            "insert into users (user_id, user_name) " "values (%s, %s)",
            (4, "sally"),
        )
        conn.exec_driver_sql("insert into users (user_id) values (%s)", (5,))
        res = conn.exec_driver_sql("select * from users order by user_id")
        assert res.fetchall() == [
            (1, "jack"),
            (2, "ed"),
            (3, "horse"),
            (4, "sally"),
            (5, None),
        ]

        res = conn.exec_driver_sql(
            "select * from users where user_name=%s", ("jack",)
        )
        assert res.fetchall() == [(1, "jack")]

    @testing.requires.pyformat_paramstyle
    def test_raw_python(self, connection):
        """Single- and multi-row raw inserts, pyformat style."""
        conn = connection
        conn.exec_driver_sql(
            "insert into users (user_id, user_name) "
            "values (%(id)s, %(name)s)",
            {"id": 1, "name": "jack"},
        )
        conn.exec_driver_sql(
            "insert into users (user_id, user_name) "
            "values (%(id)s, %(name)s)",
            [{"id": 2, "name": "ed"}, {"id": 3, "name": "horse"}],
        )
        conn.exec_driver_sql(
            "insert into users (user_id, user_name) "
            "values (%(id)s, %(name)s)",
            dict(id=4, name="sally"),
        )
        res = conn.exec_driver_sql("select * from users order by user_id")
        assert res.fetchall() == [
            (1, "jack"),
            (2, "ed"),
            (3, "horse"),
            (4, "sally"),
        ]

    @testing.requires.named_paramstyle
    def test_raw_named(self, connection):
        """Single- and multi-row raw inserts, named style."""
        conn = connection
        conn.exec_driver_sql(
            "insert into users (user_id, user_name) " "values (:id, :name)",
            {"id": 1, "name": "jack"},
        )
        conn.exec_driver_sql(
            "insert into users (user_id, user_name) " "values (:id, :name)",
            [{"id": 2, "name": "ed"}, {"id": 3, "name": "horse"}],
        )
        conn.exec_driver_sql(
            "insert into users (user_id, user_name) " "values (:id, :name)",
            {"id": 4, "name": "sally"},
        )
        res = conn.exec_driver_sql("select * from users order by user_id")
        assert res.fetchall() == [
            (1, "jack"),
            (2, "ed"),
            (3, "horse"),
            (4, "sally"),
        ]

    def test_non_dict_mapping(self, connection):
        """ensure arbitrary Mapping works for execute()"""

        class NotADict(collections_abc.Mapping):
            def __init__(self, _data):
                self._data = _data

            def __iter__(self):
                return iter(self._data)

            def __len__(self):
                return len(self._data)

            def __getitem__(self, key):
                return self._data[key]

            def keys(self):
                return self._data.keys()

        nd = NotADict({"a": 10, "b": 15})
        eq_(dict(nd), {"a": 10, "b": 15})

        result = connection.execute(
            select(
                bindparam("a", type_=Integer), bindparam("b", type_=Integer)
            ),
            nd,
        )
        eq_(result.first(), (10, 15))

    def test_row_works_as_mapping(self, connection):
        """ensure the RowMapping object works as a parameter dictionary for
        execute."""

        result = connection.execute(
            select(literal(10).label("a"), literal(15).label("b"))
        )
        row = result.first()
        eq_(row, (10, 15))
        eq_(row._mapping, {"a": 10, "b": 15})

        result = connection.execute(
            select(
                bindparam("a", type_=Integer).label("a"),
                bindparam("b", type_=Integer).label("b"),
            ),
            row._mapping,
        )
        row = result.first()
        eq_(row, (10, 15))
        eq_(row._mapping, {"a": 10, "b": 15})

    def test_dialect_has_table_assertion(self):
        """Dialect.has_table() raises ArgumentError when given an Engine."""
        with expect_raises_message(
            tsa.exc.ArgumentError,
            r"The argument passed to Dialect.has_table\(\) should be a",
        ):
            testing.db.dialect.has_table(testing.db, "some_table")

    def test_exception_wrapping_dbapi(self):
        """Invalid SQL surfaces as a DBAPIError that mentions the SQL."""
        with testing.db.connect() as conn:
            # engine does not have exec_driver_sql
            assert_raises_message(
                tsa.exc.DBAPIError,
                r"not_a_valid_statement",
                conn.exec_driver_sql,
                "not_a_valid_statement",
            )

    @testing.requires.sqlite
    def test_exception_wrapping_non_dbapi_error(self):
        """A non-DBAPI error from the cursor propagates unwrapped, and the
        dialect's is_disconnect() is never consulted."""
        e = create_engine("sqlite://")
        e.dialect.is_disconnect = is_disconnect = Mock()

        with e.connect() as c:
            c.connection.cursor = Mock(
                return_value=Mock(
                    execute=Mock(
                        side_effect=TypeError("I'm not a DBAPI error")
                    )
                )
            )

            assert_raises_message(
                TypeError,
                "I'm not a DBAPI error",
                c.exec_driver_sql,
                "select ",
            )
            eq_(is_disconnect.call_count, 0)

    def test_exception_wrapping_non_standard_dbapi_error(self):
        """A DBAPI exception that subclasses the DBAPI's OperationalError is
        wrapped as tsa.exc.OperationalError."""

        class DBAPIError(Exception):
            pass

        class OperationalError(DBAPIError):
            pass

        class NonStandardException(OperationalError):
            pass

        # TODO: this test is assuming too much of arbitrary dialects and would
        # be better suited tested against a single mock dialect that does not
        # have any special behaviors
        with patch.object(
            testing.db.dialect, "dbapi", Mock(Error=DBAPIError)
        ), patch.object(
            testing.db.dialect, "is_disconnect", lambda *arg: False
        ), patch.object(
            testing.db.dialect,
            "do_execute",
            Mock(side_effect=NonStandardException),
        ), patch.object(
            testing.db.dialect.execution_ctx_cls,
            "handle_dbapi_exception",
            Mock(),
        ):
            with testing.db.connect() as conn:
                assert_raises(
                    tsa.exc.OperationalError, conn.exec_driver_sql, "select 1"
                )

    def test_exception_wrapping_non_dbapi_statement(self):
        """An error raised while processing a bind parameter is wrapped in a
        StatementError whose message includes the original error and SQL."""

        class MyType(TypeDecorator):
            impl = Integer
            cache_ok = True

            def process_bind_param(self, value, dialect):
                raise SomeException("nope")

        def _go(conn):
            assert_raises_message(
                tsa.exc.StatementError,
                r"\(.*.SomeException\) " r"nope\n\[SQL\: u?SELECT 1 ",
                conn.execute,
                select(1).where(column("foo") == literal("bar", MyType())),
            )

        with testing.db.connect() as conn:
            _go(conn)

    def test_not_an_executable(self):
        """Non-executable SQL and schema constructs raise
        ObjectNotExecutableError."""
        for obj in (
            Table("foo", MetaData(), Column("x", Integer)),
            Column("x", Integer),
            tsa.and_(True),
            tsa.and_(True).compile(),
            column("foo"),
            column("foo").compile(),
            select(1).cte(),
            # select(1).subquery() is not listed: executing a subquery is
            # only deprecated so far (see test_subquery_exec_warning)
            MetaData(),
            Integer(),
            tsa.Index(name="foo"),
            tsa.UniqueConstraint("x"),
        ):
            with testing.db.connect() as conn:
                assert_raises_message(
                    tsa.exc.ObjectNotExecutableError,
                    "Not an executable object",
                    conn.execute,
                    obj,
                )

    def test_subquery_exec_warning(self):
        """alias() and subquery() objects can still be executed, but a
        deprecation warning is emitted."""
        for obj in (select(1).alias(), select(1).subquery()):
            with testing.db.connect() as conn:
                with expect_deprecated(
                    "Executing a subquery object is deprecated and will "
                    "raise ObjectNotExecutableError"
                ):
                    eq_(conn.execute(obj).scalar(), 1)

    def test_stmt_exception_bytestring_raised(self):
        """A StatementError about a missing bind value renders a non-ASCII
        column label in its message."""
        name = util.u("méil")
        users = self.tables.users
        with testing.db.connect() as conn:
            assert_raises_message(
                tsa.exc.StatementError,
                util.u(
                    "A value is required for bind parameter 'uname'\n"
                    r".*SELECT users.user_name AS .m\xe9il."
                )
                if util.py2k
                else util.u(
                    "A value is required for bind parameter 'uname'\n"
                    ".*SELECT users.user_name AS .méil."
                ),
                conn.execute,
                select(users.c.user_name.label(name)).where(
                    users.c.user_name == bindparam("uname")
                ),
                {"uname_incorrect": "foo"},
            )

    def test_stmt_exception_bytestring_utf8(self):
        # uncommon case for Py3K, bytestring object passed
        # as the error message
        message = util.u("some message méil").encode("utf-8")

        err = tsa.exc.SQLAlchemyError(message)
        if util.py2k:
            # string passes it through
            eq_(str(err), message)

            # unicode accessor decodes to utf-8
            eq_(unicode(err), util.u("some message méil"))  # noqa F821
        else:
            eq_(str(err), util.u("some message méil"))

    def test_stmt_exception_bytestring_latin1(self):
        # uncommon case for Py3K, bytestring object passed
        # as the error message
        message = util.u("some message méil").encode("latin-1")

        err = tsa.exc.SQLAlchemyError(message)
        if util.py2k:
            # string passes it through
            eq_(str(err), message)

            # latin-1 bytes are not valid utf-8, so the non-ASCII byte comes
            # back as a literal backslash escape instead of being decoded
            eq_(unicode(err), util.u("some message m\\xe9il"))  # noqa F821
        else:
            eq_(str(err), util.u("some message m\\xe9il"))

    def test_stmt_exception_unicode_hook_unicode(self):
        # uncommon case for Py2K, Unicode object passed
        # as the error message
        message = util.u("some message méil")

        err = tsa.exc.SQLAlchemyError(message)
        if util.py2k:
            eq_(unicode(err), util.u("some message méil"))  # noqa F821
        else:
            eq_(str(err), util.u("some message méil"))

    def test_stmt_exception_object_arg(self):
        """A non-string message argument is stringified via its own
        __str__ / __unicode__."""
        err = tsa.exc.SQLAlchemyError(Foo())
        eq_(str(err), "foo")

        if util.py2k:
            eq_(unicode(err), util.u("fóó"))  # noqa F821

    def test_stmt_exception_str_multi_args(self):
        """With several args, str() matches str() of the args tuple."""
        err = tsa.exc.SQLAlchemyError("some message", 206)
        eq_(str(err), "('some message', 206)")

    def test_stmt_exception_str_multi_args_bytestring(self):
        message = util.u("some message méil").encode("utf-8")

        err = tsa.exc.SQLAlchemyError(message, 206)
        eq_(str(err), str((message, 206)))

    def test_stmt_exception_str_multi_args_unicode(self):
        message = util.u("some message méil")

        err = tsa.exc.SQLAlchemyError(message, 206)
        eq_(str(err), str((message, 206)))

    def test_stmt_exception_pickleable_no_dbapi(self):
        self._test_stmt_exception_pickleable(Exception("hello world"))

    @testing.crashes(
        "postgresql+psycopg2",
        "Older versions don't support cursor pickling, newer ones do",
    )
    @testing.fails_on(
        "mysql+oursql",
        "Exception doesn't come back exactly the same from pickle",
    )
    @testing.fails_on(
        "mysql+mysqlconnector",
        "Exception doesn't come back exactly the same from pickle",
    )
    @testing.fails_on(
        "oracle+cx_oracle",
        "cx_oracle exception seems to be having " "some issue with pickling",
    )
    def test_stmt_exception_pickleable_plus_dbapi(self):
        """Same as the no-dbapi variant, but wraps a real DBAPI exception
        obtained by running invalid SQL on a raw connection."""
        raw = testing.db.raw_connection()
        the_orig = None
        try:
            try:
                cursor = raw.cursor()
                cursor.execute("SELECTINCORRECT")
            except testing.db.dialect.dbapi.Error as orig:
                # Python 3 unbinds the "as" target when the except clause
                # ends, so keep a separate reference to the exception.
                the_orig = orig
        finally:
            raw.close()
        # fail clearly here rather than with an AttributeError on None.args
        assert the_orig is not None, "invalid SQL did not raise a DBAPI error"
        self._test_stmt_exception_pickleable(the_orig)

    def _test_stmt_exception_pickleable(self, orig):
        """Round-trip several SQLAlchemy exceptions through every available
        pickler, checking args, statement/params, connection_invalidated
        and the wrapped ``orig`` exception survive."""
        for sa_exc in (
            tsa.exc.StatementError(
                "some error",
                "select * from table",
                {"foo": "bar"},
                orig,
                False,
            ),
            tsa.exc.InterfaceError(
                "select * from table", {"foo": "bar"}, orig, True
            ),
            tsa.exc.NoReferencedTableError("message", "tname"),
            tsa.exc.NoReferencedColumnError("message", "tname", "cname"),
            tsa.exc.CircularDependencyError(
                "some message", [1, 2, 3], [(1, 2), (3, 4)]
            ),
        ):
            for loads, dumps in picklers():
                repickled = loads(dumps(sa_exc))
                eq_(repickled.args[0], sa_exc.args[0])
                if isinstance(sa_exc, tsa.exc.StatementError):
                    eq_(repickled.params, {"foo": "bar"})
                    eq_(repickled.statement, sa_exc.statement)
                    if hasattr(sa_exc, "connection_invalidated"):
                        eq_(
                            repickled.connection_invalidated,
                            sa_exc.connection_invalidated,
                        )
                    eq_(repickled.orig.args[0], orig.args[0])

    def test_dont_wrap_mixin(self):
        """Exceptions that mix in DontWrapMixin propagate unwrapped from
        bind processing."""

        class MyException(Exception, tsa.exc.DontWrapMixin):
            pass

        class MyType(TypeDecorator):
            impl = Integer
            cache_ok = True

            def process_bind_param(self, value, dialect):
                raise MyException("nope")

        def _go(conn):
            assert_raises_message(
                MyException,
                "nope",
                conn.execute,
                select(1).where(column("foo") == literal("bar", MyType())),
            )

        with testing.db.connect() as conn:
            _go(conn)

    def test_empty_insert(self, connection):
        """test that execute() interprets [] as a list with no params"""
        users_autoinc = self.tables.users_autoinc

        connection.execute(
            users_autoinc.insert().values(user_name=bindparam("name", None)),
            [],
        )
        eq_(connection.execute(users_autoinc.select()).fetchall(), [(1, None)])

    @testing.only_on("sqlite")
    def test_execute_compiled_favors_compiled_paramstyle(self):
        """Executing a Compiled object uses the paramstyle of the dialect
        it was compiled with, not the connection's dialect."""
        users = self.tables.users

        with patch.object(testing.db.dialect, "do_execute") as do_exec:
            stmt = users.update().values(user_id=1, user_name="foo")

            d1 = default.DefaultDialect(paramstyle="format")
            d2 = default.DefaultDialect(paramstyle="pyformat")

            with testing.db.begin() as conn:
                conn.execute(stmt.compile(dialect=d1))
                conn.execute(stmt.compile(dialect=d2))

            eq_(
                do_exec.mock_calls,
                [
                    call(
                        mock.ANY,
                        "UPDATE users SET user_id=%s, user_name=%s",
                        (1, "foo"),
                        mock.ANY,
                    ),
                    call(
                        mock.ANY,
                        "UPDATE users SET user_id=%(user_id)s, "
                        "user_name=%(user_name)s",
                        {"user_name": "foo", "user_id": 1},
                        mock.ANY,
                    ),
                ],
            )

    @testing.requires.ad_hoc_engines
    def test_engine_level_options(self):
        """Engine-level execution options are seen by connections.

        Connection-level options override or extend them, and
        update_execution_options() affects connections opened afterwards.
        """
        eng = engines.testing_engine(
            options={"execution_options": {"foo": "bar"}}
        )
        with eng.connect() as conn:
            eq_(conn._execution_options["foo"], "bar")
            eq_(
                conn.execution_options(bat="hoho")._execution_options["foo"],
                "bar",
            )
            eq_(
                conn.execution_options(bat="hoho")._execution_options["bat"],
                "hoho",
            )
            eq_(
                conn.execution_options(foo="hoho")._execution_options["foo"],
                "hoho",
            )
            eng.update_execution_options(foo="hoho")
            # open the post-update connection in its own context manager so
            # it is closed too, instead of rebinding ``conn`` and leaking it
            with eng.connect() as conn2:
                eq_(conn2._execution_options["foo"], "hoho")

    @testing.requires.ad_hoc_engines
    def test_generative_engine_execution_options(self):
        """Engine.execution_options() returns derived engines.

        Each derived engine has its own merged options dict but shares the
        parent's pool.
        """
        eng = engines.testing_engine(
            options={"execution_options": {"base": "x1"}}
        )

        is_(eng.engine, eng)

        eng1 = eng.execution_options(foo="b1")
        is_(eng1.engine, eng1)
        eng2 = eng.execution_options(foo="b2")
        eng1a = eng1.execution_options(bar="a1")
        eng2a = eng2.execution_options(foo="b3", bar="a2")
        is_(eng2a.engine, eng2a)

        eq_(eng._execution_options, {"base": "x1"})
        eq_(eng1._execution_options, {"base": "x1", "foo": "b1"})
        eq_(eng2._execution_options, {"base": "x1", "foo": "b2"})
        eq_(eng1a._execution_options, {"base": "x1", "foo": "b1", "bar": "a1"})
        eq_(eng2a._execution_options, {"base": "x1", "foo": "b3", "bar": "a2"})
        is_(eng1a.pool, eng.pool)

        # test pool is shared
        eng2.dispose()
        is_(eng1a.pool, eng2.pool)
        is_(eng.pool, eng2.pool)

    @testing.requires.ad_hoc_engines
    def test_autocommit_option_no_issue_first_connect(self):
        """An engine-level autocommit option does not interfere with the
        first connect, and it appears on the resulting connection."""
        eng = create_engine(testing.db.url)
        eng.update_execution_options(autocommit=True)
        with eng.connect() as conn:
            eq_(conn._execution_options, {"autocommit": True})

    def test_initialize_rollback(self):
        """test a rollback happens during first connect"""
        eng = create_engine(testing.db.url)
        with patch.object(eng.dialect, "do_rollback") as do_rollback:
            assert do_rollback.call_count == 0
            connection = eng.connect()
            assert do_rollback.call_count == 1
        connection.close()

    @testing.requires.ad_hoc_engines
    def test_dialect_init_uses_options(self):
        """Execution options used inside dialect.initialize() do not leak
        onto the connection handed back to the caller."""
        eng = create_engine(testing.db.url)

        def my_init(connection):
            connection.execution_options(foo="bar").execute(select(1))

        with patch.object(eng.dialect, "initialize", my_init):
            with eng.connect() as conn:
                eq_(conn._execution_options, {})

    @testing.requires.ad_hoc_engines
    def test_generative_engine_event_dispatch_hasevents(self):
        """An engine derived via execution_options() inherits _has_events
        from its parent engine."""

        def l1(*arg, **kw):
            pass

        eng = create_engine(testing.db.url)
        assert not eng._has_events
        event.listen(eng, "before_execute", l1)
        eng2 = eng.execution_options(foo="bar")
        assert eng2._has_events

    def test_works_after_dispose(self):
        """An engine stays usable across repeated dispose() calls."""
        eng = create_engine(testing.db.url)
        for i in range(3):
            with eng.connect() as conn:
                eq_(conn.scalar(select(1)), 1)
            eng.dispose()

    def test_works_after_dispose_testing_engine(self):
        """Same as test_works_after_dispose, using a testing_engine()."""
        eng = engines.testing_engine()
        for i in range(3):
            with eng.connect() as conn:
                eq_(conn.scalar(select(1)), 1)
            eng.dispose()

    def test_scalar(self, connection):
        """Connection.scalar() returns the first column of the first row."""
        conn = connection
        users = self.tables.users
        conn.execute(
            users.insert(),
            [
                {"user_id": 1, "user_name": "sandy"},
                {"user_id": 2, "user_name": "spongebob"},
            ],
        )
        res = conn.scalar(select(users.c.user_name).order_by(users.c.user_id))
        eq_(res, "sandy")

    def test_scalars(self, connection):
        """Connection.scalars() yields the first column of every row."""
        conn = connection
        users = self.tables.users
        conn.execute(
            users.insert(),
            [
                {"user_id": 1, "user_name": "sandy"},
                {"user_id": 2, "user_name": "spongebob"},
            ],
        )
        res = conn.scalars(select(users.c.user_name).order_by(users.c.user_id))
        eq_(res.all(), ["sandy", "spongebob"])
788
class UnicodeReturnsTest(fixtures.TestBase):
    """Dialect detection of whether the DBAPI returns unicode strings."""

    @testing.requires.python3
    def test_unicode_test_not_in_python3(self):
        """String.RETURNS_UNKNOWN is refused at connect time on Python 3."""
        eng = engines.testing_engine()
        eng.dialect.returns_unicode_strings = String.RETURNS_UNKNOWN

        assert_raises_message(
            tsa.exc.InvalidRequestError,
            "RETURNS_UNKNOWN is unsupported in Python 3",
            eng.connect,
        )

    @testing.requires.python2
    def test_unicode_test_fails_warning(self):
        """If the unicode-returns probe query fails, connecting warns
        instead of raising and falls back to RETURNS_CONDITIONAL."""

        class MockCursor(engines.DBAPIProxyCursor):
            def execute(self, stmt, params=None, **kw):
                if "test unicode returns" in stmt:
                    raise self.engine.dialect.dbapi.DatabaseError("boom")
                else:
                    return super(MockCursor, self).execute(stmt, params, **kw)

        eng = engines.proxying_engine(cursor_cls=MockCursor)
        try:
            with testing.expect_warnings(
                "Exception attempting to detect unicode returns"
            ):
                conn = eng.connect()
            # Release the connection. dispose() below does not close
            # connections that are still checked out.
            conn.close()

            # because plain varchar passed, we don't know the correct answer
            eq_(
                eng.dialect.returns_unicode_strings,
                String.RETURNS_CONDITIONAL,
            )
        finally:
            eng.dispose()
819
820
class ConvenienceExecuteTest(fixtures.TablesTest):
    """Tests for running work inside ``begin()`` context managers.

    Each test inserts a row into ``exec_test`` from inside a transactional
    context and then checks, over a separate connection, whether the row
    was committed or rolled back.
    """

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        cls.table = Table(
            "exec_test",
            metadata,
            Column("a", Integer),
            Column("b", Integer),
            test_needs_acid=True,
        )

    def _trans_fn(self, is_transaction=False):
        """Return ``go(conn, x, value=None)``, which inserts ``(x, value)``.

        :param is_transaction: if True, ``go`` receives an object exposing
         the real connection as ``.connection`` (e.g. the transaction
         obtained from ``conn.begin()``) and runs the INSERT on that.
        """

        def go(conn, x, value=None):
            if is_transaction:
                conn = conn.connection
            conn.execute(self.table.insert().values(a=x, b=value))

        return go

    def _trans_rollback_fn(self, is_transaction=False):
        """Return a callable that performs the same INSERT as
        :meth:`_trans_fn` and then raises ``SomeException("breakage")``,
        so tests can check that the surrounding transaction rolls back.
        """
        # reuse _trans_fn rather than duplicating the INSERT logic
        insert = self._trans_fn(is_transaction)

        def go(conn, x, value=None):
            insert(conn, x, value=value)
            raise SomeException("breakage")

        return go

    def _assert_no_data(self):
        """Assert, on a fresh connection, that ``exec_test`` is empty."""
        with testing.db.connect() as conn:
            eq_(
                conn.scalar(select(func.count("*")).select_from(self.table)),
                0,
            )

    def _assert_fn(self, x, value=None):
        """Assert, on a fresh connection, that ``exec_test`` holds exactly
        the single row ``(x, value)``."""
        with testing.db.connect() as conn:
            eq_(conn.execute(self.table.select()).fetchall(), [(x, value)])

    def test_transaction_engine_ctx_commit(self):
        fn = self._trans_fn()
        ctx = testing.db.begin()
        testing.run_as_contextmanager(ctx, fn, 5, value=8)
        self._assert_fn(5, value=8)

    def test_transaction_engine_ctx_begin_fails_dont_enter_enter(self):
        """test #7272

        ``Connection.begin()`` is mocked to raise.  On legacy engines the
        error surfaces from ``engine.begin()`` itself and the connection
        must be closed; otherwise nothing is called until the context
        manager is entered, so nothing is closed.
        """
        engine = engines.testing_engine()

        mock_connection = Mock(
            return_value=Mock(begin=Mock(side_effect=Exception("boom")))
        )
        with mock.patch.object(engine, "_connection_cls", mock_connection):
            if testing.requires.legacy_engine.enabled:
                with expect_raises_message(Exception, "boom"):
                    engine.begin()
            else:
                # context manager isn't entered, doesn't actually call
                # connect() or connection.begin()
                engine.begin()

        if testing.requires.legacy_engine.enabled:
            eq_(mock_connection.return_value.close.mock_calls, [call()])
        else:
            eq_(mock_connection.return_value.close.mock_calls, [])

    def test_transaction_engine_ctx_begin_fails_include_enter(self):
        """test #7272

        When ``begin()`` fails while entering ``engine.begin()``, the
        connection is closed exactly once.
        """
        engine = engines.testing_engine()

        close_mock = Mock()
        with mock.patch.object(
            engine._connection_cls,
            "begin",
            Mock(side_effect=Exception("boom")),
        ), mock.patch.object(engine._connection_cls, "close", close_mock):
            with expect_raises_message(Exception, "boom"):
                with engine.begin():
                    pass

        eq_(close_mock.mock_calls, [call()])

    def test_transaction_engine_ctx_rollback(self):
        fn = self._trans_rollback_fn()
        ctx = testing.db.begin()
        assert_raises_message(
            Exception,
            "breakage",
            testing.run_as_contextmanager,
            ctx,
            fn,
            5,
            value=8,
        )
        self._assert_no_data()

    def test_transaction_connection_ctx_commit(self):
        fn = self._trans_fn(True)
        with testing.db.connect() as conn:
            ctx = conn.begin()
            testing.run_as_contextmanager(ctx, fn, 5, value=8)
            self._assert_fn(5, value=8)

    def test_transaction_connection_ctx_rollback(self):
        fn = self._trans_rollback_fn(True)
        with testing.db.connect() as conn:
            ctx = conn.begin()
            assert_raises_message(
                Exception,
                "breakage",
                testing.run_as_contextmanager,
                ctx,
                fn,
                5,
                value=8,
            )
            self._assert_no_data()

    def test_connection_as_ctx(self):
        fn = self._trans_fn()
        with testing.db.begin() as conn:
            fn(conn, 5, value=8)
        self._assert_fn(5, value=8)

    @testing.fails_on("mysql+oursql", "oursql bug ?  getting wrong rowcount")
    @testing.requires.legacy_engine
    def test_connect_as_ctx_noautocommit(self):
        fn = self._trans_fn()
        self._assert_no_data()

        with testing.db.connect() as conn:
            ctx = conn.execution_options(autocommit=False)
            testing.run_as_contextmanager(ctx, fn, 5, value=8)
            # autocommit is off, so a separate connection still sees
            # no rows
            self._assert_no_data()
958
959
class FutureConvenienceExecuteTest(
    fixtures.FutureEngineMixin,
    ConvenienceExecuteTest,
):
    """Re-run every ConvenienceExecuteTest test with
    ``fixtures.FutureEngineMixin`` applied."""

    __backend__ = True
964
965
class CompiledCacheTest(fixtures.TestBase):
    """Tests for the ``compiled_cache`` execution option."""

    __backend__ = True

    def test_cache(self, connection, metadata):
        """Executing the same insert() three times through a connection
        with a ``compiled_cache`` compiles it once and stores one entry."""
        users = Table(
            "users",
            metadata,
            Column(
                "user_id", INT, primary_key=True, test_needs_autoincrement=True
            ),
            Column("user_name", VARCHAR(20)),
            Column("extra_data", VARCHAR(20)),
        )
        users.create(connection)

        conn = connection
        cache = {}
        cached_conn = conn.execution_options(compiled_cache=cache)

        ins = users.insert()
        # wrap _compiler so calls are counted while compilation still runs
        with patch.object(
            ins, "_compiler", Mock(side_effect=ins._compiler)
        ) as compile_mock:
            cached_conn.execute(ins, {"user_name": "u1"})
            cached_conn.execute(ins, {"user_name": "u2"})
            cached_conn.execute(ins, {"user_name": "u3"})
        eq_(compile_mock.call_count, 1)
        eq_(len(cache), 1)
        eq_(conn.exec_driver_sql("select count(*) from users").scalar(), 3)

    @testing.only_on(
        ["sqlite", "mysql", "postgresql"],
        "uses blob value that is problematic for some DBAPIs",
    )
    def test_cache_noleak_on_statement_values(self, metadata, connection):
        # This is a non regression test for an object reference leak caused
        # by the compiled_cache.

        photo = Table(
            "photo",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("photo_blob", LargeBinary()),
        )
        metadata.create_all(connection)

        cache = {}
        cached_conn = connection.execution_options(compiled_cache=cache)

        # bytearray subclass so that instances can be weakly referenced
        class PhotoBlob(bytearray):
            pass

        blob = PhotoBlob(100)
        ref_blob = weakref.ref(blob)

        ins = photo.insert()
        with patch.object(
            ins, "_compiler", Mock(side_effect=ins._compiler)
        ) as compile_mock:
            cached_conn.execute(ins, {"photo_blob": blob})
        eq_(compile_mock.call_count, 1)
        eq_(len(cache), 1)
        eq_(
            connection.exec_driver_sql("select count(*) from photo").scalar(),
            1,
        )

        del blob

        gc_collect()

        # The compiled statement cache should not hold any reference to
        # the statement values (only the keys).
        eq_(ref_blob(), None)

    def test_keys_independent_of_ordering(self, connection, metadata):
        """Parameter dictionaries with the same keys in different orders
        share a single compiled-cache entry."""
        users = Table(
            "users",
            metadata,
            Column(
                "user_id", INT, primary_key=True, test_needs_autoincrement=True
            ),
            Column("user_name", VARCHAR(20)),
            Column("extra_data", VARCHAR(20)),
        )
        users.create(connection)

        connection.execute(
            users.insert(),
            {"user_id": 1, "user_name": "u1", "extra_data": "e1"},
        )
        cache = {}
        cached_conn = connection.execution_options(compiled_cache=cache)

        upd = users.update().where(users.c.user_id == bindparam("b_user_id"))

        with patch.object(
            upd, "_compiler", Mock(side_effect=upd._compiler)
        ) as compile_mock:
            cached_conn.execute(
                upd,
                util.OrderedDict(
                    [
                        ("b_user_id", 1),
                        ("user_name", "u2"),
                        ("extra_data", "e2"),
                    ]
                ),
            )
            cached_conn.execute(
                upd,
                util.OrderedDict(
                    [
                        ("b_user_id", 1),
                        ("extra_data", "e3"),
                        ("user_name", "u3"),
                    ]
                ),
            )
            cached_conn.execute(
                upd,
                util.OrderedDict(
                    [
                        ("extra_data", "e4"),
                        ("user_name", "u4"),
                        ("b_user_id", 1),
                    ]
                ),
            )
        eq_(compile_mock.call_count, 1)
        eq_(len(cache), 1)

    @testing.requires.schemas
    def test_schema_translate_in_key(self, metadata, connection):
        """A shared compiled cache must honor ``schema_translate_map``,
        routing the same statement to the default or translated schema."""
        Table("x", metadata, Column("q", Integer))
        Table("x", metadata, Column("q", Integer), schema=config.test_schema)
        metadata.create_all(connection)

        m = MetaData()
        t1 = Table("x", m, Column("q", Integer))
        ins = t1.insert()
        stmt = select(t1.c.q)

        cache = {}

        conn = connection.execution_options(compiled_cache=cache)
        conn.execute(ins, {"q": 1})
        eq_(conn.scalar(stmt), 1)

        conn = connection.execution_options(
            compiled_cache=cache,
            schema_translate_map={None: config.test_schema},
        )
        conn.execute(ins, {"q": 2})
        eq_(conn.scalar(stmt), 2)

        conn = connection.execution_options(
            compiled_cache=cache,
            schema_translate_map={None: None},
        )
        # should use default schema again even though statement
        # was compiled with test_schema in the map
        eq_(conn.scalar(stmt), 1)

        conn = connection.execution_options(
            compiled_cache=cache,
        )
        eq_(conn.scalar(stmt), 1)
1136
1137
class MockStrategyTest(fixtures.TestBase):
    """Exercise DDL generation through ``create_mock_engine``."""

    def _engine_fixture(self):
        """Return ``(engine, buffer)``: every statement handed to the mock
        engine is compiled with the engine's (PostgreSQL) dialect and its
        text appended to ``buffer``."""
        captured = util.StringIO()

        def record(sql, *multiparams, **params):
            compiled = sql.compile(dialect=mock_engine.dialect)
            captured.write(util.text_type(compiled))

        mock_engine = create_mock_engine("postgresql://", executor=record)
        return mock_engine, captured

    def test_sequence_not_duped(self):
        engine, buf = self._engine_fixture()
        pk_col = Column(
            "pk",
            Integer,
            Sequence("testtable_pk_seq"),
            primary_key=True,
        )
        t = Table("testtable", MetaData(), pk_col)

        t.create(engine)
        t.drop(engine)

        ddl = buf.getvalue()
        # the sequence is created once before the table and dropped once
        # after it
        eq_(re.findall(r"CREATE (\w+)", ddl), ["SEQUENCE", "TABLE"])
        eq_(re.findall(r"DROP (\w+)", ddl), ["TABLE", "SEQUENCE"])
1168
1169
class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
    """Tests for the ``schema_translate_map`` execution option.

    Statements are asserted in compiled form, where schema names appear as
    ``__[SCHEMA_<name>]`` placeholders (``__[SCHEMA__none]`` for tables
    with no schema).  Each test maps ``None`` and ``"foo"`` onto
    ``config.test_schema`` and ``"bar"`` onto the default schema.
    """

    __requires__ = ("schemas",)
    __backend__ = True

    @testing.fixture
    def plain_tables(self, metadata):
        # t1/t2 live in test_schema and t3 in the default schema, matching
        # where the translate map used by the tests sends them
        t1 = Table(
            "t1", metadata, Column("x", Integer), schema=config.test_schema
        )
        t2 = Table(
            "t2", metadata, Column("x", Integer), schema=config.test_schema
        )
        t3 = Table("t3", metadata, Column("x", Integer), schema=None)

        return t1, t2, t3

    def test_create_table(self, plain_tables, connection):
        map_ = {
            None: config.test_schema,
            "foo": config.test_schema,
            "bar": None,
        }

        metadata = MetaData()
        t1 = Table("t1", metadata, Column("x", Integer))
        t2 = Table("t2", metadata, Column("x", Integer), schema="foo")
        t3 = Table("t3", metadata, Column("x", Integer), schema="bar")

        with self.sql_execution_asserter(connection) as asserter:
            conn = connection.execution_options(schema_translate_map=map_)

            t1.create(conn)
            t2.create(conn)
            t3.create(conn)

            t3.drop(conn)
            t2.drop(conn)
            t1.drop(conn)

        asserter.assert_(
            CompiledSQL("CREATE TABLE __[SCHEMA__none].t1 (x INTEGER)"),
            CompiledSQL("CREATE TABLE __[SCHEMA_foo].t2 (x INTEGER)"),
            CompiledSQL("CREATE TABLE __[SCHEMA_bar].t3 (x INTEGER)"),
            CompiledSQL("DROP TABLE __[SCHEMA_bar].t3"),
            CompiledSQL("DROP TABLE __[SCHEMA_foo].t2"),
            CompiledSQL("DROP TABLE __[SCHEMA__none].t1"),
        )

    def test_ddl_hastable(self, plain_tables, connection):

        map_ = {
            None: config.test_schema,
            "foo": config.test_schema,
            "bar": None,
        }

        metadata = MetaData()
        Table("t1", metadata, Column("x", Integer))
        Table("t2", metadata, Column("x", Integer), schema="foo")
        Table("t3", metadata, Column("x", Integer), schema="bar")

        conn = connection.execution_options(schema_translate_map=map_)
        metadata.create_all(conn)

        # drop in "finally" so a failing has_table() assertion doesn't
        # leave the translated tables behind for later tests
        try:
            insp = inspect(connection)
            is_true(insp.has_table("t1", schema=config.test_schema))
            is_true(insp.has_table("t2", schema=config.test_schema))
            is_true(insp.has_table("t3", schema=None))
        finally:
            metadata.drop_all(conn)

        insp = inspect(connection)
        is_false(insp.has_table("t1", schema=config.test_schema))
        is_false(insp.has_table("t2", schema=config.test_schema))
        is_false(insp.has_table("t3", schema=None))

    def test_option_on_execute(self, plain_tables, connection):
        # provided by metadata fixture provided by plain_tables fixture
        self.metadata.create_all(connection)

        map_ = {
            None: config.test_schema,
            "foo": config.test_schema,
            "bar": None,
        }

        metadata = MetaData()
        t1 = Table("t1", metadata, Column("x", Integer))
        t2 = Table("t2", metadata, Column("x", Integer), schema="foo")
        t3 = Table("t3", metadata, Column("x", Integer), schema="bar")

        with self.sql_execution_asserter(connection) as asserter:
            conn = connection
            # here the map is passed per-statement rather than being set
            # on the connection
            execution_options = {"schema_translate_map": map_}
            conn._execute_20(
                t1.insert(), {"x": 1}, execution_options=execution_options
            )
            conn._execute_20(
                t2.insert(), {"x": 1}, execution_options=execution_options
            )
            conn._execute_20(
                t3.insert(), {"x": 1}, execution_options=execution_options
            )

            conn._execute_20(
                t1.update().values(x=1).where(t1.c.x == 1),
                execution_options=execution_options,
            )
            conn._execute_20(
                t2.update().values(x=2).where(t2.c.x == 1),
                execution_options=execution_options,
            )
            conn._execute_20(
                t3.update().values(x=3).where(t3.c.x == 1),
                execution_options=execution_options,
            )

            eq_(
                conn._execute_20(
                    select(t1.c.x), execution_options=execution_options
                ).scalar(),
                1,
            )
            eq_(
                conn._execute_20(
                    select(t2.c.x), execution_options=execution_options
                ).scalar(),
                2,
            )
            eq_(
                conn._execute_20(
                    select(t3.c.x), execution_options=execution_options
                ).scalar(),
                3,
            )

            conn._execute_20(t1.delete(), execution_options=execution_options)
            conn._execute_20(t2.delete(), execution_options=execution_options)
            conn._execute_20(t3.delete(), execution_options=execution_options)

        asserter.assert_(
            CompiledSQL("INSERT INTO __[SCHEMA__none].t1 (x) VALUES (:x)"),
            CompiledSQL("INSERT INTO __[SCHEMA_foo].t2 (x) VALUES (:x)"),
            CompiledSQL("INSERT INTO __[SCHEMA_bar].t3 (x) VALUES (:x)"),
            CompiledSQL(
                "UPDATE __[SCHEMA__none].t1 SET x=:x WHERE "
                "__[SCHEMA__none].t1.x = :x_1"
            ),
            CompiledSQL(
                "UPDATE __[SCHEMA_foo].t2 SET x=:x WHERE "
                "__[SCHEMA_foo].t2.x = :x_1"
            ),
            CompiledSQL(
                "UPDATE __[SCHEMA_bar].t3 SET x=:x WHERE "
                "__[SCHEMA_bar].t3.x = :x_1"
            ),
            CompiledSQL(
                "SELECT __[SCHEMA__none].t1.x FROM __[SCHEMA__none].t1"
            ),
            CompiledSQL("SELECT __[SCHEMA_foo].t2.x FROM __[SCHEMA_foo].t2"),
            CompiledSQL("SELECT __[SCHEMA_bar].t3.x FROM __[SCHEMA_bar].t3"),
            CompiledSQL("DELETE FROM __[SCHEMA__none].t1"),
            CompiledSQL("DELETE FROM __[SCHEMA_foo].t2"),
            CompiledSQL("DELETE FROM __[SCHEMA_bar].t3"),
        )

    def test_crud(self, plain_tables, connection):
        # provided by metadata fixture provided by plain_tables fixture
        self.metadata.create_all(connection)

        map_ = {
            None: config.test_schema,
            "foo": config.test_schema,
            "bar": None,
        }

        metadata = MetaData()
        t1 = Table("t1", metadata, Column("x", Integer))
        t2 = Table("t2", metadata, Column("x", Integer), schema="foo")
        t3 = Table("t3", metadata, Column("x", Integer), schema="bar")

        with self.sql_execution_asserter(connection) as asserter:
            conn = connection.execution_options(schema_translate_map=map_)

            conn.execute(t1.insert(), {"x": 1})
            conn.execute(t2.insert(), {"x": 1})
            conn.execute(t3.insert(), {"x": 1})

            conn.execute(t1.update().values(x=1).where(t1.c.x == 1))
            conn.execute(t2.update().values(x=2).where(t2.c.x == 1))
            conn.execute(t3.update().values(x=3).where(t3.c.x == 1))

            eq_(conn.scalar(select(t1.c.x)), 1)
            eq_(conn.scalar(select(t2.c.x)), 2)
            eq_(conn.scalar(select(t3.c.x)), 3)

            conn.execute(t1.delete())
            conn.execute(t2.delete())
            conn.execute(t3.delete())

        asserter.assert_(
            CompiledSQL("INSERT INTO __[SCHEMA__none].t1 (x) VALUES (:x)"),
            CompiledSQL("INSERT INTO __[SCHEMA_foo].t2 (x) VALUES (:x)"),
            CompiledSQL("INSERT INTO __[SCHEMA_bar].t3 (x) VALUES (:x)"),
            CompiledSQL(
                "UPDATE __[SCHEMA__none].t1 SET x=:x WHERE "
                "__[SCHEMA__none].t1.x = :x_1"
            ),
            CompiledSQL(
                "UPDATE __[SCHEMA_foo].t2 SET x=:x WHERE "
                "__[SCHEMA_foo].t2.x = :x_1"
            ),
            CompiledSQL(
                "UPDATE __[SCHEMA_bar].t3 SET x=:x WHERE "
                "__[SCHEMA_bar].t3.x = :x_1"
            ),
            CompiledSQL(
                "SELECT __[SCHEMA__none].t1.x FROM __[SCHEMA__none].t1"
            ),
            CompiledSQL("SELECT __[SCHEMA_foo].t2.x FROM __[SCHEMA_foo].t2"),
            CompiledSQL("SELECT __[SCHEMA_bar].t3.x FROM __[SCHEMA_bar].t3"),
            CompiledSQL("DELETE FROM __[SCHEMA__none].t1"),
            CompiledSQL("DELETE FROM __[SCHEMA_foo].t2"),
            CompiledSQL("DELETE FROM __[SCHEMA_bar].t3"),
        )

    def test_via_engine(self, plain_tables, metadata):

        with config.db.begin() as connection:
            metadata.create_all(connection)

        map_ = {
            None: config.test_schema,
            "foo": config.test_schema,
            "bar": None,
        }

        # note: rebinds the fixture's ``metadata`` name to a fresh,
        # untranslated MetaData for the statement under test
        metadata = MetaData()
        t2 = Table("t2", metadata, Column("x", Integer), schema="foo")

        with self.sql_execution_asserter(config.db) as asserter:
            # map set on the engine, inherited by its connections
            eng = config.db.execution_options(schema_translate_map=map_)
            with eng.connect() as conn:
                conn.execute(select(t2.c.x))
        asserter.assert_(
            CompiledSQL("SELECT __[SCHEMA_foo].t2.x FROM __[SCHEMA_foo].t2")
        )
1421
1422
class ExecutionOptionsTest(fixtures.TestBase):
    """Tests for how execution options reach the dialect and propagate
    from engines to connections."""

    def test_dialect_conn_options(self, testing_engine):
        engine = testing_engine("sqlite://", options=dict(_initialize=False))
        # replace the dialect with a Mock so the calls made to it by
        # Connection.execution_options() can be inspected
        engine.dialect = Mock()
        with engine.connect() as conn:
            c2 = conn.execution_options(foo="bar")
            eq_(
                engine.dialect.set_connection_execution_options.mock_calls,
                [call(c2, {"foo": "bar"})],
            )

    def test_dialect_engine_options(self, testing_engine):
        engine = testing_engine("sqlite://")
        engine.dialect = Mock()
        e2 = engine.execution_options(foo="bar")
        eq_(
            engine.dialect.set_engine_execution_options.mock_calls,
            [call(e2, {"foo": "bar"})],
        )

    def test_dialect_engine_construction_options(self):
        dialect = Mock()
        engine = Engine(
            Mock(), dialect, Mock(), execution_options={"foo": "bar"}
        )
        eq_(
            dialect.set_engine_execution_options.mock_calls,
            [call(engine, {"foo": "bar"})],
        )

    def test_propagate_engine_to_connection(self, testing_engine):
        engine = testing_engine(
            "sqlite://", options=dict(execution_options={"foo": "bar"})
        )
        with engine.connect() as conn:
            eq_(conn._execution_options, {"foo": "bar"})

    def test_propagate_option_engine_to_connection(self, testing_engine):
        e1 = testing_engine(
            "sqlite://", options=dict(execution_options={"foo": "bar"})
        )
        e2 = e1.execution_options(bat="hoho")
        # context managers ensure both connections are closed even if an
        # assertion below fails
        with e1.connect() as c1, e2.connect() as c2:
            eq_(c1._execution_options, {"foo": "bar"})
            eq_(c2._execution_options, {"foo": "bar", "bat": "hoho"})

    def test_get_engine_execution_options(self, testing_engine):
        engine = testing_engine("sqlite://")
        engine.dialect = Mock()
        e2 = engine.execution_options(foo="bar")

        eq_(e2.get_execution_options(), {"foo": "bar"})

    def test_get_connection_execution_options(self, testing_engine):
        engine = testing_engine("sqlite://", options=dict(_initialize=False))
        engine.dialect = Mock()
        with engine.connect() as conn:
            c = conn.execution_options(foo="bar")

            eq_(c.get_execution_options(), {"foo": "bar"})
1487
1488
1489class EngineEventsTest(fixtures.TestBase):
1490    __requires__ = ("ad_hoc_engines",)
1491    __backend__ = True
1492
    def teardown_test(self):
        # Reset class-level event state on Engine after each test so that
        # listeners registered on ``Engine`` by one test do not leak into
        # the next (``dispatch._clear()`` presumably drops all registered
        # listeners; ``_has_events`` is reset to its no-events value).
        Engine.dispatch._clear()
        Engine._has_events = False
1496
1497    def _assert_stmts(self, expected, received):
1498        list(received)
1499
1500        for stmt, params, posn in expected:
1501            if not received:
1502                assert False, "Nothing available for stmt: %s" % stmt
1503            while received:
1504                teststmt, testparams, testmultiparams = received.pop(0)
1505                teststmt = (
1506                    re.compile(r"[\n\t ]+", re.M).sub(" ", teststmt).strip()
1507                )
1508                if teststmt.startswith(stmt) and (
1509                    testparams == params or testparams == posn
1510                ):
1511                    break
1512
1513    def test_per_engine_independence(self, testing_engine):
1514        e1 = testing_engine(config.db_url)
1515        e2 = testing_engine(config.db_url)
1516
1517        canary = Mock()
1518        event.listen(e1, "before_execute", canary)
1519        s1 = select(1)
1520        s2 = select(2)
1521
1522        with e1.connect() as conn:
1523            conn.execute(s1)
1524
1525        with e2.connect() as conn:
1526            conn.execute(s2)
1527        eq_([arg[1][1] for arg in canary.mock_calls], [s1])
1528        event.listen(e2, "before_execute", canary)
1529
1530        with e1.connect() as conn:
1531            conn.execute(s1)
1532
1533        with e2.connect() as conn:
1534            conn.execute(s2)
1535        eq_([arg[1][1] for arg in canary.mock_calls], [s1, s1, s2])
1536
1537    def test_per_engine_plus_global(self, testing_engine):
1538        canary = Mock()
1539        event.listen(Engine, "before_execute", canary.be1)
1540        e1 = testing_engine(config.db_url)
1541        e2 = testing_engine(config.db_url)
1542
1543        event.listen(e1, "before_execute", canary.be2)
1544
1545        event.listen(Engine, "before_execute", canary.be3)
1546
1547        with e1.connect() as conn:
1548            conn.execute(select(1))
1549        eq_(canary.be1.call_count, 1)
1550        eq_(canary.be2.call_count, 1)
1551
1552        with e2.connect() as conn:
1553            conn.execute(select(1))
1554
1555        eq_(canary.be1.call_count, 2)
1556        eq_(canary.be2.call_count, 1)
1557        eq_(canary.be3.call_count, 2)
1558
1559    def test_emit_sql_in_autobegin(self, testing_engine):
1560        e1 = testing_engine(config.db_url)
1561
1562        canary = Mock()
1563
1564        @event.listens_for(e1, "begin")
1565        def begin(connection):
1566            result = connection.execute(select(1)).scalar()
1567            canary.got_result(result)
1568
1569        with e1.connect() as conn:
1570            assert not conn._is_future
1571
1572            with conn.begin():
1573                conn.execute(select(1)).scalar()
1574                assert conn.in_transaction()
1575
1576            assert not conn.in_transaction()
1577
1578        eq_(canary.mock_calls, [call.got_result(1)])
1579
1580    def test_per_connection_plus_engine(self, testing_engine):
1581        canary = Mock()
1582        e1 = testing_engine(config.db_url)
1583
1584        event.listen(e1, "before_execute", canary.be1)
1585
1586        conn = e1.connect()
1587        event.listen(conn, "before_execute", canary.be2)
1588        conn.execute(select(1))
1589
1590        eq_(canary.be1.call_count, 1)
1591        eq_(canary.be2.call_count, 1)
1592
1593        if testing.requires.legacy_engine.enabled:
1594            conn._branch().execute(select(1))
1595            eq_(canary.be1.call_count, 2)
1596            eq_(canary.be2.call_count, 2)
1597
    @testing.combinations(
        (True, False),
        (True, True),
        (False, False),
        argnames="mock_out_on_connect, add_our_own_onconnect",
    )
    def test_insert_connect_is_definitely_first(
        self, mock_out_on_connect, add_our_own_onconnect, testing_engine
    ):
        """test issue #5708.

        We want to ensure that a single "connect" event may be invoked
        *before* dialect initialize as well as before dialect on_connects.

        This is also partially reliant on the changes we made as a result of
        #5497, however here we go further with the changes and remove use
        of the pool first_connect() event entirely so that the startup
        for a dialect is fully consistent.

        """
        # Optionally patch the dialect class's on_connect() hook: either
        # with one returning a callable that records into ``m1``, or with
        # one returning None (no on-connect callable at all).
        if mock_out_on_connect:
            if add_our_own_onconnect:

                # ``m1`` is bound by the ``with ... as m1`` block below
                # before any connection is made, so this late lookup works
                def our_connect(connection):
                    m1.our_connect("our connect event")

                patcher = mock.patch.object(
                    config.db.dialect.__class__,
                    "on_connect",
                    lambda self: our_connect,
                )
            else:
                patcher = mock.patch.object(
                    config.db.dialect.__class__,
                    "on_connect",
                    lambda self: None,
                )
        else:
            patcher = util.nullcontext()

        with patcher:
            e1 = testing_engine(config.db_url)

            # wrap the real initialize() so each call is recorded on m1
            # (as a bare ``call(connection)``) while still running
            initialize = e1.dialect.initialize

            def init(connection):
                initialize(connection)

            with mock.patch.object(
                e1.dialect, "initialize", side_effect=init
            ) as m1:

                # insert=True places this listener ahead of any others
                @event.listens_for(e1, "connect", insert=True)
                def go1(dbapi_conn, xyz):
                    m1.foo("custom event first")

                @event.listens_for(e1, "connect")
                def go2(dbapi_conn, xyz):
                    m1.foo("custom event last")

                c1 = e1.connect()

                # marker separating events of the first and second connect
                m1.bar("ok next connection")

                c2 = e1.connect()

                # this happens with sqlite singletonthreadpool.
                # we can almost use testing.requires.independent_connections
                # but sqlite file backend will also have independent
                # connections here.
                its_the_same_connection = (
                    c1.connection.dbapi_connection
                    is c2.connection.dbapi_connection
                )
                c1.close()
                c2.close()

        # Expected order for the first connection: the inserted listener,
        # then the dialect's own on_connect (when patched in), then
        # initialize() (the bare ``call(ANY)`` on m1), then the appended
        # listener.
        if add_our_own_onconnect:
            calls = [
                mock.call.foo("custom event first"),
                mock.call.our_connect("our connect event"),
                mock.call(mock.ANY),
                mock.call.foo("custom event last"),
                mock.call.bar("ok next connection"),
            ]
        else:
            calls = [
                mock.call.foo("custom event first"),
                mock.call(mock.ANY),
                mock.call.foo("custom event last"),
                mock.call.bar("ok next connection"),
            ]

        # A second, distinct DBAPI connection fires the connect events
        # again, but initialize() is not expected to run a second time.
        if not its_the_same_connection:
            if add_our_own_onconnect:
                calls.extend(
                    [
                        mock.call.foo("custom event first"),
                        mock.call.our_connect("our connect event"),
                        mock.call.foo("custom event last"),
                    ]
                )
            else:
                calls.extend(
                    [
                        mock.call.foo("custom event first"),
                        mock.call.foo("custom event last"),
                    ]
                )
        eq_(m1.mock_calls, calls)
1708
1709    def test_new_exec_driver_sql_no_events(self):
1710        m1 = Mock()
1711
1712        def select1(db):
1713            return str(select(1).compile(dialect=db.dialect))
1714
1715        with testing.db.connect() as conn:
1716            event.listen(conn, "before_execute", m1.before_execute)
1717            event.listen(conn, "after_execute", m1.after_execute)
1718            conn.exec_driver_sql(select1(testing.db))
1719        eq_(m1.mock_calls, [])
1720
1721    def test_add_event_after_connect(self, testing_engine):
1722        # new feature as of #2978
1723
1724        canary = Mock()
1725        e1 = testing_engine(config.db_url, future=False)
1726        assert not e1._has_events
1727
1728        conn = e1.connect()
1729
1730        event.listen(e1, "before_execute", canary.be1)
1731        conn.execute(select(1))
1732
1733        eq_(canary.be1.call_count, 1)
1734
1735        conn._branch().execute(select(1))
1736        eq_(canary.be1.call_count, 2)
1737
1738    def test_force_conn_events_false(self, testing_engine):
1739        canary = Mock()
1740        e1 = testing_engine(config.db_url, future=False)
1741        assert not e1._has_events
1742
1743        event.listen(e1, "before_execute", canary.be1)
1744
1745        conn = e1._connection_cls(
1746            e1, connection=e1.raw_connection(), _has_events=False
1747        )
1748
1749        conn.execute(select(1))
1750
1751        eq_(canary.be1.call_count, 0)
1752
1753        conn._branch().execute(select(1))
1754        eq_(canary.be1.call_count, 0)
1755
1756    def test_cursor_events_ctx_execute_scalar(self, testing_engine):
1757        canary = Mock()
1758        e1 = testing_engine(config.db_url)
1759
1760        event.listen(e1, "before_cursor_execute", canary.bce)
1761        event.listen(e1, "after_cursor_execute", canary.ace)
1762
1763        stmt = str(select(1).compile(dialect=e1.dialect))
1764
1765        with e1.connect() as conn:
1766            dialect = conn.dialect
1767
1768            ctx = dialect.execution_ctx_cls._init_statement(
1769                dialect, conn, conn.connection, {}, stmt, {}
1770            )
1771
1772            ctx._execute_scalar(stmt, Integer())
1773
1774        eq_(
1775            canary.bce.mock_calls,
1776            [call(conn, ctx.cursor, stmt, ctx.parameters[0], ctx, False)],
1777        )
1778        eq_(
1779            canary.ace.mock_calls,
1780            [call(conn, ctx.cursor, stmt, ctx.parameters[0], ctx, False)],
1781        )
1782
1783    def test_cursor_events_execute(self, testing_engine):
1784        canary = Mock()
1785        e1 = testing_engine(config.db_url)
1786
1787        event.listen(e1, "before_cursor_execute", canary.bce)
1788        event.listen(e1, "after_cursor_execute", canary.ace)
1789
1790        stmt = str(select(1).compile(dialect=e1.dialect))
1791
1792        with e1.connect() as conn:
1793
1794            result = conn.exec_driver_sql(stmt)
1795            eq_(result.scalar(), 1)
1796
1797        ctx = result.context
1798        eq_(
1799            canary.bce.mock_calls,
1800            [call(conn, ctx.cursor, stmt, ctx.parameters[0], ctx, False)],
1801        )
1802        eq_(
1803            canary.ace.mock_calls,
1804            [call(conn, ctx.cursor, stmt, ctx.parameters[0], ctx, False)],
1805        )
1806
1807    @testing.combinations(
1808        (
1809            ([{"x": 5, "y": 10}, {"x": 8, "y": 9}],),
1810            {},
1811            [{"x": 5, "y": 10}, {"x": 8, "y": 9}],
1812            {},
1813        ),
1814        (({"z": 10},), {}, [], {"z": 10}),
1815        argnames="multiparams, params, expected_multiparams, expected_params",
1816    )
1817    def test_modify_parameters_from_event_one(
1818        self,
1819        multiparams,
1820        params,
1821        expected_multiparams,
1822        expected_params,
1823        testing_engine,
1824    ):
1825        # this is testing both the normalization added to parameters
1826        # as of I97cb4d06adfcc6b889f10d01cc7775925cffb116 as well as
1827        # that the return value from the event is taken as the new set
1828        # of parameters.
1829        def before_execute(
1830            conn, clauseelement, multiparams, params, execution_options
1831        ):
1832            eq_(multiparams, expected_multiparams)
1833            eq_(params, expected_params)
1834            return clauseelement, (), {"q": "15"}
1835
1836        def after_execute(
1837            conn, clauseelement, multiparams, params, result, execution_options
1838        ):
1839            eq_(multiparams, ())
1840            eq_(params, {"q": "15"})
1841
1842        e1 = testing_engine(config.db_url)
1843        event.listen(e1, "before_execute", before_execute, retval=True)
1844        event.listen(e1, "after_execute", after_execute)
1845
1846        with e1.connect() as conn:
1847            result = conn.execute(
1848                select(bindparam("q", type_=String)), *multiparams, **params
1849            )
1850            eq_(result.all(), [("15",)])
1851
1852    @testing.provide_metadata
1853    def test_modify_parameters_from_event_two(self, connection):
1854        t = Table("t", self.metadata, Column("q", Integer))
1855
1856        t.create(connection)
1857
1858        def before_execute(
1859            conn, clauseelement, multiparams, params, execution_options
1860        ):
1861            return clauseelement, [{"q": 15}, {"q": 19}], {}
1862
1863        event.listen(connection, "before_execute", before_execute, retval=True)
1864        connection.execute(t.insert(), {"q": 12})
1865        event.remove(connection, "before_execute", before_execute)
1866
1867        eq_(
1868            connection.execute(select(t).order_by(t.c.q)).fetchall(),
1869            [(15,), (19,)],
1870        )
1871
1872    def test_modify_parameters_from_event_three(
1873        self, connection, testing_engine
1874    ):
1875        def before_execute(
1876            conn, clauseelement, multiparams, params, execution_options
1877        ):
1878            return clauseelement, [{"q": 15}, {"q": 19}], {"q": 7}
1879
1880        e1 = testing_engine(config.db_url)
1881        event.listen(e1, "before_execute", before_execute, retval=True)
1882
1883        with expect_raises_message(
1884            tsa.exc.InvalidRequestError,
1885            "Event handler can't return non-empty multiparams "
1886            "and params at the same time",
1887        ):
1888            with e1.connect() as conn:
1889                conn.execute(select(literal("1")))
1890
1891    @testing.only_on("sqlite")
1892    def test_dont_modify_statement_driversql(self, connection):
1893        m1 = mock.Mock()
1894
1895        @event.listens_for(connection, "before_execute", retval=True)
1896        def _modify(
1897            conn, clauseelement, multiparams, params, execution_options
1898        ):
1899            m1.run_event()
1900            return clauseelement.replace("hi", "there"), multiparams, params
1901
1902        # the event does not take effect for the "driver SQL" option
1903        eq_(connection.exec_driver_sql("select 'hi'").scalar(), "hi")
1904
1905        # event is not called at all
1906        eq_(m1.mock_calls, [])
1907
1908    @testing.combinations((True,), (False,), argnames="future")
1909    @testing.only_on("sqlite")
1910    def test_modify_statement_internal_driversql(self, connection, future):
1911        m1 = mock.Mock()
1912
1913        @event.listens_for(connection, "before_execute", retval=True)
1914        def _modify(
1915            conn, clauseelement, multiparams, params, execution_options
1916        ):
1917            m1.run_event()
1918            return clauseelement.replace("hi", "there"), multiparams, params
1919
1920        eq_(
1921            connection._exec_driver_sql(
1922                "select 'hi'", [], {}, {}, future=future
1923            ).scalar(),
1924            "hi" if future else "there",
1925        )
1926
1927        if future:
1928            eq_(m1.mock_calls, [])
1929        else:
1930            eq_(m1.mock_calls, [call.run_event()])
1931
1932    def test_modify_statement_clauseelement(self, connection):
1933        @event.listens_for(connection, "before_execute", retval=True)
1934        def _modify(
1935            conn, clauseelement, multiparams, params, execution_options
1936        ):
1937            return select(literal_column("'there'")), multiparams, params
1938
1939        eq_(connection.scalar(select(literal_column("'hi'"))), "there")
1940
1941    def test_argument_format_execute(self, testing_engine):
1942        def before_execute(
1943            conn, clauseelement, multiparams, params, execution_options
1944        ):
1945            assert isinstance(multiparams, (list, tuple))
1946            assert isinstance(params, collections_abc.Mapping)
1947
1948        def after_execute(
1949            conn, clauseelement, multiparams, params, result, execution_options
1950        ):
1951            assert isinstance(multiparams, (list, tuple))
1952            assert isinstance(params, collections_abc.Mapping)
1953
1954        e1 = testing_engine(config.db_url)
1955        event.listen(e1, "before_execute", before_execute)
1956        event.listen(e1, "after_execute", after_execute)
1957
1958        with e1.connect() as conn:
1959            conn.execute(select(1))
1960            conn.execute(select(1).compile(dialect=e1.dialect).statement)
1961            conn.execute(select(1).compile(dialect=e1.dialect))
1962
1963            conn._execute_compiled(
1964                select(1).compile(dialect=e1.dialect), (), {}, {}
1965            )
1966
1967    def test_execute_events(self):
1968
1969        stmts = []
1970        cursor_stmts = []
1971
1972        def execute(
1973            conn, clauseelement, multiparams, params, execution_options
1974        ):
1975            stmts.append((str(clauseelement), params, multiparams))
1976
1977        def cursor_execute(
1978            conn, cursor, statement, parameters, context, executemany
1979        ):
1980            cursor_stmts.append((str(statement), parameters, None))
1981
1982        # TODO: this test is kind of a mess
1983
1984        for engine in [
1985            engines.testing_engine(options=dict(implicit_returning=False)),
1986            engines.testing_engine(
1987                options=dict(implicit_returning=False)
1988            ).connect(),
1989        ]:
1990            event.listen(engine, "before_execute", execute)
1991            event.listen(engine, "before_cursor_execute", cursor_execute)
1992            m = MetaData()
1993            t1 = Table(
1994                "t1",
1995                m,
1996                Column("c1", Integer, primary_key=True),
1997                Column(
1998                    "c2",
1999                    String(50),
2000                    default=func.lower("Foo"),
2001                    primary_key=True,
2002                ),
2003            )
2004
2005            if isinstance(engine, Connection):
2006                ctx = None
2007                conn = engine
2008            else:
2009                ctx = conn = engine.connect()
2010
2011            trans = conn.begin()
2012            try:
2013                m.create_all(conn, checkfirst=False)
2014                try:
2015                    conn.execute(t1.insert(), dict(c1=5, c2="some data"))
2016                    conn.execute(t1.insert(), dict(c1=6))
2017                    eq_(
2018                        conn.execute(text("select * from t1")).fetchall(),
2019                        [(5, "some data"), (6, "foo")],
2020                    )
2021                finally:
2022                    m.drop_all(conn)
2023                    trans.commit()
2024            finally:
2025                if ctx:
2026                    ctx.close()
2027
2028            compiled = [
2029                ("CREATE TABLE t1", {}, None),
2030                (
2031                    "INSERT INTO t1 (c1, c2)",
2032                    {"c2": "some data", "c1": 5},
2033                    (),
2034                ),
2035                ("INSERT INTO t1 (c1, c2)", {"c1": 6}, ()),
2036                ("select * from t1", {}, None),
2037                ("DROP TABLE t1", {}, None),
2038            ]
2039
2040            cursor = [
2041                ("CREATE TABLE t1", {}, ()),
2042                (
2043                    "INSERT INTO t1 (c1, c2)",
2044                    {"c2": "some data", "c1": 5},
2045                    (5, "some data"),
2046                ),
2047                ("SELECT lower", {"lower_2": "Foo"}, ("Foo",)),
2048                (
2049                    "INSERT INTO t1 (c1, c2)",
2050                    {"c2": "foo", "c1": 6},
2051                    (6, "foo"),
2052                ),
2053                ("select * from t1", {}, ()),
2054                ("DROP TABLE t1", {}, ()),
2055            ]
2056            self._assert_stmts(compiled, stmts)
2057            self._assert_stmts(cursor, cursor_stmts)
2058
2059    def test_options(self):
2060        canary = []
2061
2062        def execute(conn, *args, **kw):
2063            canary.append("execute")
2064
2065        def cursor_execute(conn, *args, **kw):
2066            canary.append("cursor_execute")
2067
2068        engine = engines.testing_engine()
2069        event.listen(engine, "before_execute", execute)
2070        event.listen(engine, "before_cursor_execute", cursor_execute)
2071        conn = engine.connect()
2072        c2 = conn.execution_options(foo="bar")
2073        eq_(c2._execution_options, {"foo": "bar"})
2074        c2.execute(select(1))
2075        c3 = c2.execution_options(bar="bat")
2076        eq_(c3._execution_options, {"foo": "bar", "bar": "bat"})
2077        eq_(canary, ["execute", "cursor_execute"])
2078
2079    @testing.requires.ad_hoc_engines
2080    def test_generative_engine_event_dispatch(self):
2081        canary = []
2082
2083        def l1(*arg, **kw):
2084            canary.append("l1")
2085
2086        def l2(*arg, **kw):
2087            canary.append("l2")
2088
2089        def l3(*arg, **kw):
2090            canary.append("l3")
2091
2092        eng = engines.testing_engine(
2093            options={"execution_options": {"base": "x1"}}
2094        )
2095        event.listen(eng, "before_execute", l1)
2096
2097        eng1 = eng.execution_options(foo="b1")
2098        event.listen(eng, "before_execute", l2)
2099        event.listen(eng1, "before_execute", l3)
2100
2101        with eng.connect() as conn:
2102            conn.execute(select(1))
2103
2104        eq_(canary, ["l1", "l2"])
2105
2106        with eng1.connect() as conn:
2107            conn.execute(select(1))
2108
2109        eq_(canary, ["l1", "l2", "l3", "l1", "l2"])
2110
2111    @testing.requires.ad_hoc_engines
2112    def test_clslevel_engine_event_options(self):
2113        canary = []
2114
2115        def l1(*arg, **kw):
2116            canary.append("l1")
2117
2118        def l2(*arg, **kw):
2119            canary.append("l2")
2120
2121        def l3(*arg, **kw):
2122            canary.append("l3")
2123
2124        def l4(*arg, **kw):
2125            canary.append("l4")
2126
2127        event.listen(Engine, "before_execute", l1)
2128
2129        eng = engines.testing_engine(
2130            options={"execution_options": {"base": "x1"}}
2131        )
2132        event.listen(eng, "before_execute", l2)
2133
2134        eng1 = eng.execution_options(foo="b1")
2135        event.listen(eng, "before_execute", l3)
2136        event.listen(eng1, "before_execute", l4)
2137
2138        with eng.connect() as conn:
2139            conn.execute(select(1))
2140
2141        eq_(canary, ["l1", "l2", "l3"])
2142
2143        with eng1.connect() as conn:
2144            conn.execute(select(1))
2145
2146        eq_(canary, ["l1", "l2", "l3", "l4", "l1", "l2", "l3"])
2147
2148        canary[:] = []
2149
2150        event.remove(Engine, "before_execute", l1)
2151        event.remove(eng1, "before_execute", l4)
2152        event.remove(eng, "before_execute", l3)
2153
2154        with eng1.connect() as conn:
2155            conn.execute(select(1))
2156        eq_(canary, ["l2"])
2157
2158    @testing.requires.ad_hoc_engines
2159    def test_cant_listen_to_option_engine(self):
2160        from sqlalchemy.engine import base
2161
2162        def evt(*arg, **kw):
2163            pass
2164
2165        assert_raises_message(
2166            tsa.exc.InvalidRequestError,
2167            r"Can't assign an event directly to the "
2168            "<class 'sqlalchemy.engine.base.OptionEngine'> class",
2169            event.listen,
2170            base.OptionEngine,
2171            "before_cursor_execute",
2172            evt,
2173        )
2174
2175    @testing.requires.ad_hoc_engines
2176    def test_dispose_event(self, testing_engine):
2177        canary = Mock()
2178        eng = testing_engine(testing.db.url)
2179        event.listen(eng, "engine_disposed", canary)
2180
2181        conn = eng.connect()
2182        conn.close()
2183        eng.dispose()
2184
2185        conn = eng.connect()
2186        conn.close()
2187
2188        eq_(canary.mock_calls, [call(eng)])
2189
2190        eng.dispose()
2191
2192        eq_(canary.mock_calls, [call(eng), call(eng)])
2193
2194    def test_retval_flag(self):
2195        canary = []
2196
2197        def tracker(name):
2198            def go(conn, *args, **kw):
2199                canary.append(name)
2200
2201            return go
2202
2203        def execute(
2204            conn, clauseelement, multiparams, params, execution_options
2205        ):
2206            canary.append("execute")
2207            return clauseelement, multiparams, params
2208
2209        def cursor_execute(
2210            conn, cursor, statement, parameters, context, executemany
2211        ):
2212            canary.append("cursor_execute")
2213            return statement, parameters
2214
2215        engine = engines.testing_engine()
2216
2217        assert_raises(
2218            tsa.exc.ArgumentError,
2219            event.listen,
2220            engine,
2221            "begin",
2222            tracker("begin"),
2223            retval=True,
2224        )
2225
2226        event.listen(engine, "before_execute", execute, retval=True)
2227        event.listen(
2228            engine, "before_cursor_execute", cursor_execute, retval=True
2229        )
2230        with engine.connect() as conn:
2231            conn.execute(select(1))
2232        eq_(canary, ["execute", "cursor_execute"])
2233
2234    @testing.requires.legacy_engine
2235    def test_engine_connect(self):
2236        engine = engines.testing_engine()
2237
2238        tracker = Mock()
2239        event.listen(engine, "engine_connect", tracker)
2240
2241        c1 = engine.connect()
2242        c2 = c1._branch()
2243        c1.close()
2244        eq_(tracker.mock_calls, [call(c1, False), call(c2, True)])
2245
2246    def test_execution_options(self):
2247        engine = engines.testing_engine()
2248
2249        engine_tracker = Mock()
2250        conn_tracker = Mock()
2251
2252        event.listen(engine, "set_engine_execution_options", engine_tracker)
2253        event.listen(engine, "set_connection_execution_options", conn_tracker)
2254
2255        e2 = engine.execution_options(e1="opt_e1")
2256        c1 = engine.connect()
2257        c2 = c1.execution_options(c1="opt_c1")
2258        c3 = e2.connect()
2259        c4 = c3.execution_options(c3="opt_c3")
2260        eq_(engine_tracker.mock_calls, [call(e2, {"e1": "opt_e1"})])
2261        eq_(
2262            conn_tracker.mock_calls,
2263            [call(c2, {"c1": "opt_c1"}), call(c4, {"c3": "opt_c3"})],
2264        )
2265
2266    @testing.requires.sequences
2267    @testing.provide_metadata
2268    def test_cursor_execute(self):
2269        canary = []
2270
2271        def tracker(name):
2272            def go(conn, cursor, statement, parameters, context, executemany):
2273                canary.append((statement, context))
2274
2275            return go
2276
2277        engine = engines.testing_engine()
2278
2279        t = Table(
2280            "t",
2281            self.metadata,
2282            Column(
2283                "x",
2284                Integer,
2285                Sequence("t_id_seq"),
2286                primary_key=True,
2287            ),
2288            implicit_returning=False,
2289        )
2290        self.metadata.create_all(engine)
2291
2292        with engine.begin() as conn:
2293            event.listen(
2294                conn, "before_cursor_execute", tracker("cursor_execute")
2295            )
2296            conn.execute(t.insert())
2297
2298        # we see the sequence pre-executed in the first call
2299        assert "t_id_seq" in canary[0][0]
2300        assert "INSERT" in canary[1][0]
2301        # same context
2302        is_(canary[0][1], canary[1][1])
2303
2304    def test_transactional(self):
2305        canary = []
2306
2307        def tracker(name):
2308            def go(conn, *args, **kw):
2309                canary.append(name)
2310
2311            return go
2312
2313        engine = engines.testing_engine()
2314        event.listen(engine, "before_execute", tracker("execute"))
2315        event.listen(
2316            engine, "before_cursor_execute", tracker("cursor_execute")
2317        )
2318        event.listen(engine, "begin", tracker("begin"))
2319        event.listen(engine, "commit", tracker("commit"))
2320        event.listen(engine, "rollback", tracker("rollback"))
2321
2322        with engine.connect() as conn:
2323            trans = conn.begin()
2324            conn.execute(select(1))
2325            trans.rollback()
2326            trans = conn.begin()
2327            conn.execute(select(1))
2328            trans.commit()
2329
2330        eq_(
2331            canary,
2332            [
2333                "begin",
2334                "execute",
2335                "cursor_execute",
2336                "rollback",
2337                "begin",
2338                "execute",
2339                "cursor_execute",
2340                "commit",
2341            ],
2342        )
2343
2344    def test_transactional_named(self):
2345        canary = []
2346
2347        def tracker(name):
2348            def go(*args, **kw):
2349                canary.append((name, set(kw)))
2350
2351            return go
2352
2353        engine = engines.testing_engine()
2354        event.listen(engine, "before_execute", tracker("execute"), named=True)
2355        event.listen(
2356            engine,
2357            "before_cursor_execute",
2358            tracker("cursor_execute"),
2359            named=True,
2360        )
2361        event.listen(engine, "begin", tracker("begin"), named=True)
2362        event.listen(engine, "commit", tracker("commit"), named=True)
2363        event.listen(engine, "rollback", tracker("rollback"), named=True)
2364
2365        with engine.connect() as conn:
2366            trans = conn.begin()
2367            conn.execute(select(1))
2368            trans.rollback()
2369            trans = conn.begin()
2370            conn.execute(select(1))
2371            trans.commit()
2372
2373        eq_(
2374            canary,
2375            [
2376                ("begin", set(["conn"])),
2377                (
2378                    "execute",
2379                    set(
2380                        [
2381                            "conn",
2382                            "clauseelement",
2383                            "multiparams",
2384                            "params",
2385                            "execution_options",
2386                        ]
2387                    ),
2388                ),
2389                (
2390                    "cursor_execute",
2391                    set(
2392                        [
2393                            "conn",
2394                            "cursor",
2395                            "executemany",
2396                            "statement",
2397                            "parameters",
2398                            "context",
2399                        ]
2400                    ),
2401                ),
2402                ("rollback", set(["conn"])),
2403                ("begin", set(["conn"])),
2404                (
2405                    "execute",
2406                    set(
2407                        [
2408                            "conn",
2409                            "clauseelement",
2410                            "multiparams",
2411                            "params",
2412                            "execution_options",
2413                        ]
2414                    ),
2415                ),
2416                (
2417                    "cursor_execute",
2418                    set(
2419                        [
2420                            "conn",
2421                            "cursor",
2422                            "executemany",
2423                            "statement",
2424                            "parameters",
2425                            "context",
2426                        ]
2427                    ),
2428                ),
2429                ("commit", set(["conn"])),
2430            ],
2431        )
2432
2433    @testing.requires.savepoints
2434    @testing.requires.two_phase_transactions
2435    def test_transactional_advanced(self):
2436        canary1 = []
2437
2438        def tracker1(name):
2439            def go(*args, **kw):
2440                canary1.append(name)
2441
2442            return go
2443
2444        canary2 = []
2445
2446        def tracker2(name):
2447            def go(*args, **kw):
2448                canary2.append(name)
2449
2450            return go
2451
2452        engine = engines.testing_engine()
2453        for name in [
2454            "begin",
2455            "savepoint",
2456            "rollback_savepoint",
2457            "release_savepoint",
2458            "rollback",
2459            "begin_twophase",
2460            "prepare_twophase",
2461            "commit_twophase",
2462        ]:
2463            event.listen(engine, "%s" % name, tracker1(name))
2464
2465        conn = engine.connect()
2466        for name in [
2467            "begin",
2468            "savepoint",
2469            "rollback_savepoint",
2470            "release_savepoint",
2471            "rollback",
2472            "begin_twophase",
2473            "prepare_twophase",
2474            "commit_twophase",
2475        ]:
2476            event.listen(conn, "%s" % name, tracker2(name))
2477
2478        trans = conn.begin()
2479        trans2 = conn.begin_nested()
2480        conn.execute(select(1))
2481        trans2.rollback()
2482        trans2 = conn.begin_nested()
2483        conn.execute(select(1))
2484        trans2.commit()
2485        trans.rollback()
2486
2487        trans = conn.begin_twophase()
2488        conn.execute(select(1))
2489        trans.prepare()
2490        trans.commit()
2491
2492        eq_(
2493            canary1,
2494            [
2495                "begin",
2496                "savepoint",
2497                "rollback_savepoint",
2498                "savepoint",
2499                "release_savepoint",
2500                "rollback",
2501                "begin_twophase",
2502                "prepare_twophase",
2503                "commit_twophase",
2504            ],
2505        )
2506        eq_(
2507            canary2,
2508            [
2509                "begin",
2510                "savepoint",
2511                "rollback_savepoint",
2512                "savepoint",
2513                "release_savepoint",
2514                "rollback",
2515                "begin_twophase",
2516                "prepare_twophase",
2517                "commit_twophase",
2518            ],
2519        )
2520
2521
class FutureEngineEventsTest(fixtures.FutureEngineMixin, EngineEventsTest):
    """Run EngineEventsTest with future engines, plus future-only checks."""

    def test_future_fixture(self, testing_engine):
        """The testing_engine fixture yields future engines and connections."""
        engine = testing_engine()
        assert engine._is_future

        with engine.connect() as conn:
            assert conn._is_future

    def test_emit_sql_in_autobegin(self, testing_engine):
        """A "begin" listener can itself run SQL during autobegin."""
        engine = testing_engine(config.db_url)
        canary = Mock()

        def on_begin(connection):
            canary.got_result(connection.execute(select(1)).scalar())

        event.listen(engine, "begin", on_begin)

        with engine.connect() as conn:
            assert conn._is_future

            # first execute autobegins, which triggers on_begin
            conn.execute(select(1)).scalar()
            assert conn.in_transaction()

            conn.commit()
            assert not conn.in_transaction()

        eq_(canary.mock_calls, [call.got_result(1)])
2551
2552
2553class HandleErrorTest(fixtures.TestBase):
2554    __requires__ = ("ad_hoc_engines",)
2555    __backend__ = True
2556
    def teardown_test(self):
        """Reset class-level Engine event state after each test.

        Clears every listener registered on the Engine class dispatcher,
        then resets the class-level ``_has_events`` flag.
        """
        Engine.dispatch._clear()
        Engine._has_events = False
2560
2561    def test_handle_error(self):
2562        engine = engines.testing_engine()
2563        canary = Mock(return_value=None)
2564
2565        event.listen(engine, "handle_error", canary)
2566
2567        with engine.connect() as conn:
2568            try:
2569                conn.exec_driver_sql("SELECT FOO FROM I_DONT_EXIST")
2570                assert False
2571            except tsa.exc.DBAPIError as e:
2572                ctx = canary.mock_calls[0][1][0]
2573
2574                eq_(ctx.original_exception, e.orig)
2575                is_(ctx.sqlalchemy_exception, e)
2576                eq_(ctx.statement, "SELECT FOO FROM I_DONT_EXIST")
2577
2578    def test_exception_event_reraise(self):
2579        engine = engines.testing_engine()
2580
2581        class MyException(Exception):
2582            pass
2583
2584        @event.listens_for(engine, "handle_error", retval=True)
2585        def err(context):
2586            stmt = context.statement
2587            exception = context.original_exception
2588            if "ERROR ONE" in str(stmt):
2589                return MyException("my exception")
2590            elif "ERROR TWO" in str(stmt):
2591                return exception
2592            else:
2593                return None
2594
2595        conn = engine.connect()
2596        # case 1: custom exception
2597        assert_raises_message(
2598            MyException,
2599            "my exception",
2600            conn.exec_driver_sql,
2601            "SELECT 'ERROR ONE' FROM I_DONT_EXIST",
2602        )
2603        # case 2: return the DBAPI exception we're given;
2604        # no wrapping should occur
2605        assert_raises(
2606            conn.dialect.dbapi.Error,
2607            conn.exec_driver_sql,
2608            "SELECT 'ERROR TWO' FROM I_DONT_EXIST",
2609        )
2610        # case 3: normal wrapping
2611        assert_raises(
2612            tsa.exc.DBAPIError,
2613            conn.exec_driver_sql,
2614            "SELECT 'ERROR THREE' FROM I_DONT_EXIST",
2615        )
2616
2617    def test_exception_event_reraise_chaining(self):
2618        engine = engines.testing_engine()
2619
2620        class MyException1(Exception):
2621            pass
2622
2623        class MyException2(Exception):
2624            pass
2625
2626        class MyException3(Exception):
2627            pass
2628
2629        @event.listens_for(engine, "handle_error", retval=True)
2630        def err1(context):
2631            stmt = context.statement
2632
2633            if (
2634                "ERROR ONE" in str(stmt)
2635                or "ERROR TWO" in str(stmt)
2636                or "ERROR THREE" in str(stmt)
2637            ):
2638                return MyException1("my exception")
2639            elif "ERROR FOUR" in str(stmt):
2640                raise MyException3("my exception short circuit")
2641
2642        @event.listens_for(engine, "handle_error", retval=True)
2643        def err2(context):
2644            stmt = context.statement
2645            if (
2646                "ERROR ONE" in str(stmt) or "ERROR FOUR" in str(stmt)
2647            ) and isinstance(context.chained_exception, MyException1):
2648                raise MyException2("my exception chained")
2649            elif "ERROR TWO" in str(stmt):
2650                return context.chained_exception
2651            else:
2652                return None
2653
2654        conn = engine.connect()
2655
2656        with patch.object(
2657            engine.dialect.execution_ctx_cls, "handle_dbapi_exception"
2658        ) as patched:
2659            assert_raises_message(
2660                MyException2,
2661                "my exception chained",
2662                conn.exec_driver_sql,
2663                "SELECT 'ERROR ONE' FROM I_DONT_EXIST",
2664            )
2665            eq_(patched.call_count, 1)
2666
2667        with patch.object(
2668            engine.dialect.execution_ctx_cls, "handle_dbapi_exception"
2669        ) as patched:
2670            assert_raises(
2671                MyException1,
2672                conn.exec_driver_sql,
2673                "SELECT 'ERROR TWO' FROM I_DONT_EXIST",
2674            )
2675            eq_(patched.call_count, 1)
2676
2677        with patch.object(
2678            engine.dialect.execution_ctx_cls, "handle_dbapi_exception"
2679        ) as patched:
2680            # test that non None from err1 isn't cancelled out
2681            # by err2
2682            assert_raises(
2683                MyException1,
2684                conn.exec_driver_sql,
2685                "SELECT 'ERROR THREE' FROM I_DONT_EXIST",
2686            )
2687            eq_(patched.call_count, 1)
2688
2689        with patch.object(
2690            engine.dialect.execution_ctx_cls, "handle_dbapi_exception"
2691        ) as patched:
2692            assert_raises(
2693                tsa.exc.DBAPIError,
2694                conn.exec_driver_sql,
2695                "SELECT 'ERROR FIVE' FROM I_DONT_EXIST",
2696            )
2697            eq_(patched.call_count, 1)
2698
2699        with patch.object(
2700            engine.dialect.execution_ctx_cls, "handle_dbapi_exception"
2701        ) as patched:
2702            assert_raises_message(
2703                MyException3,
2704                "my exception short circuit",
2705                conn.exec_driver_sql,
2706                "SELECT 'ERROR FOUR' FROM I_DONT_EXIST",
2707            )
2708            eq_(patched.call_count, 1)
2709
2710    def test_exception_autorollback_fails(self):
2711        engine = engines.testing_engine()
2712        conn = engine.connect()
2713
2714        def boom(connection):
2715            raise engine.dialect.dbapi.OperationalError("rollback failed")
2716
2717        with expect_warnings(
2718            r"An exception has occurred during handling of a previous "
2719            r"exception.  The previous exception "
2720            r"is.*(?:i_dont_exist|does not exist)",
2721            py2konly=True,
2722        ):
2723            with patch.object(conn.dialect, "do_rollback", boom):
2724                assert_raises_message(
2725                    tsa.exc.OperationalError,
2726                    "rollback failed",
2727                    conn.exec_driver_sql,
2728                    "insert into i_dont_exist (x) values ('y')",
2729                )
2730
2731    def test_exception_event_ad_hoc_context(self):
2732        """test that handle_error is called with a context in
2733        cases where _handle_dbapi_error() is normally called without
2734        any context.
2735
2736        """
2737
2738        engine = engines.testing_engine()
2739
2740        listener = Mock(return_value=None)
2741        event.listen(engine, "handle_error", listener)
2742
2743        nope = SomeException("nope")
2744
2745        class MyType(TypeDecorator):
2746            impl = Integer
2747            cache_ok = True
2748
2749            def process_bind_param(self, value, dialect):
2750                raise nope
2751
2752        with engine.connect() as conn:
2753            assert_raises_message(
2754                tsa.exc.StatementError,
2755                r"\(.*.SomeException\) " r"nope\n\[SQL\: u?SELECT 1 ",
2756                conn.execute,
2757                select(1).where(column("foo") == literal("bar", MyType())),
2758            )
2759
2760        ctx = listener.mock_calls[0][1][0]
2761        assert ctx.statement.startswith("SELECT 1 ")
2762        is_(ctx.is_disconnect, False)
2763        is_(ctx.original_exception, nope)
2764
2765    def test_exception_event_non_dbapi_error(self):
2766        """test that handle_error is called with a context in
2767        cases where DBAPI raises an exception that is not a DBAPI
2768        exception, e.g. internal errors or encoding problems.
2769
2770        """
2771        engine = engines.testing_engine()
2772
2773        listener = Mock(return_value=None)
2774        event.listen(engine, "handle_error", listener)
2775
2776        nope = TypeError("I'm not a DBAPI error")
2777        with engine.connect() as c:
2778            c.connection.cursor = Mock(
2779                return_value=Mock(execute=Mock(side_effect=nope))
2780            )
2781
2782            assert_raises_message(
2783                TypeError,
2784                "I'm not a DBAPI error",
2785                c.exec_driver_sql,
2786                "select ",
2787            )
2788        ctx = listener.mock_calls[0][1][0]
2789        eq_(ctx.statement, "select ")
2790        is_(ctx.is_disconnect, False)
2791        is_(ctx.original_exception, nope)
2792
2793    def test_exception_event_disable_handlers(self):
2794        engine = engines.testing_engine()
2795
2796        class MyException1(Exception):
2797            pass
2798
2799        @event.listens_for(engine, "handle_error")
2800        def err1(context):
2801            stmt = context.statement
2802
2803            if "ERROR_ONE" in str(stmt):
2804                raise MyException1("my exception short circuit")
2805
2806        with engine.connect() as conn:
2807            assert_raises(
2808                tsa.exc.DBAPIError,
2809                conn.execution_options(
2810                    skip_user_error_events=True
2811                ).exec_driver_sql,
2812                "SELECT ERROR_ONE FROM I_DONT_EXIST",
2813            )
2814
2815            assert_raises(
2816                MyException1,
2817                conn.execution_options(
2818                    skip_user_error_events=False
2819                ).exec_driver_sql,
2820                "SELECT ERROR_ONE FROM I_DONT_EXIST",
2821            )
2822
2823    def _test_alter_disconnect(self, orig_error, evt_value):
2824        engine = engines.testing_engine()
2825
2826        @event.listens_for(engine, "handle_error")
2827        def evt(ctx):
2828            ctx.is_disconnect = evt_value
2829
2830        with patch.object(
2831            engine.dialect, "is_disconnect", Mock(return_value=orig_error)
2832        ):
2833
2834            with engine.connect() as c:
2835                try:
2836                    c.exec_driver_sql("SELECT x FROM nonexistent")
2837                    assert False
2838                except tsa.exc.StatementError as st:
2839                    eq_(st.connection_invalidated, evt_value)
2840
2841    def test_alter_disconnect_to_true(self):
2842        self._test_alter_disconnect(False, True)
2843        self._test_alter_disconnect(True, True)
2844
2845    def test_alter_disconnect_to_false(self):
2846        self._test_alter_disconnect(True, False)
2847        self._test_alter_disconnect(False, False)
2848
2849    @testing.requires.independent_connections
2850    def _test_alter_invalidate_pool_to_false(self, set_to_false):
2851        orig_error = True
2852
2853        engine = engines.testing_engine()
2854
2855        @event.listens_for(engine, "handle_error")
2856        def evt(ctx):
2857            if set_to_false:
2858                ctx.invalidate_pool_on_disconnect = False
2859
2860        c1, c2, c3 = (
2861            engine.pool.connect(),
2862            engine.pool.connect(),
2863            engine.pool.connect(),
2864        )
2865        crecs = [conn._connection_record for conn in (c1, c2, c3)]
2866        c1.close()
2867        c2.close()
2868        c3.close()
2869
2870        with patch.object(
2871            engine.dialect, "is_disconnect", Mock(return_value=orig_error)
2872        ):
2873
2874            with engine.connect() as c:
2875                target_crec = c.connection._connection_record
2876                try:
2877                    c.exec_driver_sql("SELECT x FROM nonexistent")
2878                    assert False
2879                except tsa.exc.StatementError as st:
2880                    eq_(st.connection_invalidated, True)
2881
2882        for crec in crecs:
2883            if crec is target_crec or not set_to_false:
2884                is_not(crec.dbapi_connection, crec.get_connection())
2885            else:
2886                is_(crec.dbapi_connection, crec.get_connection())
2887
2888    def test_alter_invalidate_pool_to_false(self):
2889        self._test_alter_invalidate_pool_to_false(True)
2890
2891    def test_alter_invalidate_pool_stays_true(self):
2892        self._test_alter_invalidate_pool_to_false(False)
2893
2894    def test_handle_error_event_connect_isolation_level(self):
2895        engine = engines.testing_engine()
2896
2897        class MySpecialException(Exception):
2898            pass
2899
2900        @event.listens_for(engine, "handle_error")
2901        def handle_error(ctx):
2902            raise MySpecialException("failed operation")
2903
2904        ProgrammingError = engine.dialect.dbapi.ProgrammingError
2905        with engine.connect() as conn:
2906            with patch.object(
2907                conn.dialect,
2908                "get_isolation_level",
2909                Mock(side_effect=ProgrammingError("random error")),
2910            ):
2911                assert_raises(MySpecialException, conn.get_isolation_level)
2912
2913    @testing.only_on("sqlite+pysqlite")
2914    def test_cursor_close_resultset_failed_connectionless(self):
2915        engine = engines.testing_engine()
2916
2917        the_conn = []
2918        the_cursor = []
2919
2920        @event.listens_for(engine, "after_cursor_execute")
2921        def go(
2922            connection, cursor, statement, parameters, context, executemany
2923        ):
2924            the_cursor.append(cursor)
2925            the_conn.append(connection)
2926
2927        with mock.patch(
2928            "sqlalchemy.engine.cursor.BaseCursorResult.__init__",
2929            Mock(side_effect=tsa.exc.InvalidRequestError("duplicate col")),
2930        ):
2931            with engine.connect() as conn:
2932                assert_raises(
2933                    tsa.exc.InvalidRequestError,
2934                    conn.execute,
2935                    text("select 1"),
2936                )
2937
2938        # cursor is closed
2939        assert_raises_message(
2940            engine.dialect.dbapi.ProgrammingError,
2941            "Cannot operate on a closed cursor",
2942            the_cursor[0].execute,
2943            "select 1",
2944        )
2945
2946        # connection is closed
2947        assert the_conn[0].closed
2948
2949    @testing.only_on("sqlite+pysqlite")
2950    def test_cursor_close_resultset_failed_explicit(self):
2951        engine = engines.testing_engine()
2952
2953        the_cursor = []
2954
2955        @event.listens_for(engine, "after_cursor_execute")
2956        def go(
2957            connection, cursor, statement, parameters, context, executemany
2958        ):
2959            the_cursor.append(cursor)
2960
2961        conn = engine.connect()
2962
2963        with mock.patch(
2964            "sqlalchemy.engine.cursor.BaseCursorResult.__init__",
2965            Mock(side_effect=tsa.exc.InvalidRequestError("duplicate col")),
2966        ):
2967            assert_raises(
2968                tsa.exc.InvalidRequestError,
2969                conn.execute,
2970                text("select 1"),
2971            )
2972
2973        # cursor is closed
2974        assert_raises_message(
2975            engine.dialect.dbapi.ProgrammingError,
2976            "Cannot operate on a closed cursor",
2977            the_cursor[0].execute,
2978            "select 1",
2979        )
2980
2981        # connection not closed
2982        assert not conn.closed
2983
2984        conn.close()
2985
2986
class OnConnectTest(fixtures.TestBase):
    """Error handling and ``handle_error`` events that occur while a DBAPI
    connection is being established or re-established.

    ``setup_test`` builds ``self.dbapi``, a ``Mock`` standing in for the
    sqlite3 module. Tests replace its ``connect`` to inject failures. Its
    ``Error`` / ``ProgrammingError`` attributes are wired to the real
    sqlite3 exception classes.
    """

    __requires__ = ("sqlite",)

    def setup_test(self):
        e = create_engine("sqlite://")

        # every successful dbapi.connect() returns this same mock connection
        connection = Mock(get_server_version_info=Mock(return_value="5.0"))

        def connect(*args, **kwargs):
            return connection

        dbapi = Mock(
            sqlite_version_info=(99, 9, 9),
            version_info=(99, 9, 9),
            sqlite_version="99.9.9",
            paramstyle="named",
            connect=Mock(side_effect=connect),
        )

        # borrow the genuine exception classes from the real driver module
        sqlite3 = e.dialect.dbapi
        dbapi.Error = (sqlite3.Error,)
        dbapi.ProgrammingError = sqlite3.ProgrammingError

        self.dbapi = dbapi
        self.ProgrammingError = sqlite3.ProgrammingError

    def test_wraps_connect_in_dbapi(self):
        """A DBAPI error from connect() is wrapped in DBAPIError and is not
        flagged as invalidating the connection."""
        dbapi = self.dbapi
        dbapi.connect = Mock(side_effect=self.ProgrammingError("random error"))
        try:
            create_engine("sqlite://", module=dbapi).connect()
            assert False
        except tsa.exc.DBAPIError as de:
            assert not de.connection_invalidated

    def test_handle_error_event_connect(self):
        """handle_error fires for connect-time failures, with the engine set
        and no Connection, and can replace the raised exception."""
        dbapi = self.dbapi
        dbapi.connect = Mock(side_effect=self.ProgrammingError("random error"))

        class MySpecialException(Exception):
            pass

        eng = create_engine("sqlite://", module=dbapi)

        @event.listens_for(eng, "handle_error")
        def handle_error(ctx):
            assert ctx.engine is eng
            assert ctx.connection is None
            raise MySpecialException("failed operation")

        assert_raises(MySpecialException, eng.connect)

    def test_handle_error_event_revalidate(self):
        """A failure while an invalidated Connection reconnects (triggered
        by accessing ``.connection``) reaches handle_error with that
        Connection as ``ctx.connection``."""
        dbapi = self.dbapi

        class MySpecialException(Exception):
            pass

        eng = create_engine("sqlite://", module=dbapi, _initialize=False)

        @event.listens_for(eng, "handle_error")
        def handle_error(ctx):
            assert ctx.engine is eng
            assert ctx.connection is conn
            assert isinstance(
                ctx.sqlalchemy_exception, tsa.exc.ProgrammingError
            )
            raise MySpecialException("failed operation")

        conn = eng.connect()
        conn.invalidate()

        dbapi.connect = Mock(side_effect=self.ProgrammingError("random error"))

        assert_raises(MySpecialException, getattr, conn, "connection")

    def test_handle_error_event_implicit_revalidate(self):
        """Same as test_handle_error_event_revalidate, but the reconnect is
        triggered implicitly by executing a statement."""
        dbapi = self.dbapi

        class MySpecialException(Exception):
            pass

        eng = create_engine("sqlite://", module=dbapi, _initialize=False)

        @event.listens_for(eng, "handle_error")
        def handle_error(ctx):
            assert ctx.engine is eng
            assert ctx.connection is conn
            assert isinstance(
                ctx.sqlalchemy_exception, tsa.exc.ProgrammingError
            )
            raise MySpecialException("failed operation")

        conn = eng.connect()
        conn.invalidate()

        dbapi.connect = Mock(side_effect=self.ProgrammingError("random error"))

        assert_raises(MySpecialException, conn.execute, select(1))

    def test_handle_error_custom_connect(self):
        """handle_error also fires when a user-supplied ``creator``
        raises."""
        dbapi = self.dbapi

        class MySpecialException(Exception):
            pass

        def custom_connect():
            raise self.ProgrammingError("random error")

        eng = create_engine("sqlite://", module=dbapi, creator=custom_connect)

        @event.listens_for(eng, "handle_error")
        def handle_error(ctx):
            assert ctx.engine is eng
            assert ctx.connection is None
            raise MySpecialException("failed operation")

        assert_raises(MySpecialException, eng.connect)

    def test_handle_error_event_connect_invalidate_flag(self):
        """A handle_error hook can clear ``ctx.is_disconnect`` for a
        connect-time error, so the DBAPIError is not flagged as
        invalidating."""
        dbapi = self.dbapi
        dbapi.connect = Mock(
            side_effect=self.ProgrammingError(
                "Cannot operate on a closed database."
            )
        )

        eng = create_engine("sqlite://", module=dbapi)

        @event.listens_for(eng, "handle_error")
        def handle_error(ctx):
            # this message is classified as a disconnect by the dialect
            assert ctx.is_disconnect
            ctx.is_disconnect = False

        try:
            eng.connect()
            assert False
        except tsa.exc.DBAPIError as de:
            assert not de.connection_invalidated

    def test_cant_connect_stay_invalidated(self):
        """If reconnecting an invalidated Connection fails with a
        disconnect-type error, the Connection remains invalidated."""
        eng = create_engine("sqlite://")

        @event.listens_for(eng, "handle_error")
        def handle_error(ctx):
            assert ctx.is_disconnect

        conn = eng.connect()

        conn.invalidate()

        # every new raw connection now fails with a disconnect-style error
        eng.pool._creator = Mock(
            side_effect=self.ProgrammingError(
                "Cannot operate on a closed database."
            )
        )

        try:
            conn.connection
            assert False
        except tsa.exc.DBAPIError:
            assert conn.invalidated

    def test_dont_touch_non_dbapi_exception_on_connect(self):
        """A non-DBAPI exception from connect() propagates unwrapped and
        is_disconnect() is never consulted."""
        dbapi = self.dbapi
        dbapi.connect = Mock(side_effect=TypeError("I'm not a DBAPI error"))

        e = create_engine("sqlite://", module=dbapi)
        e.dialect.is_disconnect = is_disconnect = Mock()
        assert_raises_message(TypeError, "I'm not a DBAPI error", e.connect)
        eq_(is_disconnect.call_count, 0)

    def test_ensure_dialect_does_is_disconnect_no_conn(self):
        """test that is_disconnect() doesn't choke if no connection,
        cursor given."""
        dialect = testing.db.dialect
        dbapi = dialect.dbapi
        assert not dialect.is_disconnect(
            dbapi.OperationalError("test"), None, None
        )

    def test_invalidate_on_connect(self):
        """test that is_disconnect() is called during connect.

        interpretation of connection failures are not supported by
        every backend.

        """
        dbapi = self.dbapi
        dbapi.connect = Mock(
            side_effect=self.ProgrammingError(
                "Cannot operate on a closed database."
            )
        )
        e = create_engine("sqlite://", module=dbapi)
        try:
            e.connect()
            assert False
        except tsa.exc.DBAPIError as de:
            assert de.connection_invalidated

    @testing.only_on("sqlite+pysqlite")
    def test_initialize_connect_calls(self):
        """test for :ticket:`5497`, on_connect not called twice"""

        m1 = Mock()
        cls_ = testing.db.dialect.__class__

        class SomeDialect(cls_):
            def initialize(self, connection):
                super(SomeDialect, self).initialize(connection)
                m1.initialize(connection)

            def on_connect(self):
                oc = super(SomeDialect, self).on_connect()

                def my_on_connect(conn):
                    if oc:
                        oc(conn)
                    m1.on_connect(conn)

                return my_on_connect

        # a URL stand-in whose entrypoint resolves to SomeDialect
        u1 = Mock(
            username=None,
            password=None,
            host=None,
            port=None,
            query={},
            database=None,
            _instantiate_plugins=lambda kw: (u1, [], kw),
            _get_entrypoint=Mock(
                return_value=Mock(get_dialect_cls=lambda u: SomeDialect)
            ),
        )
        eng = create_engine(u1, poolclass=QueuePool)
        # make sure other dialects aren't getting pulled in here
        eq_(eng.name, "sqlite")
        c = eng.connect()
        dbapi_conn_one = c.connection.dbapi_connection
        c.close()

        eq_(
            m1.mock_calls,
            [call.on_connect(dbapi_conn_one), call.initialize(mock.ANY)],
        )

        # reusing the pooled connection triggers neither hook again
        c = eng.connect()

        eq_(
            m1.mock_calls,
            [call.on_connect(dbapi_conn_one), call.initialize(mock.ANY)],
        )

        # a second, new DBAPI connection gets on_connect but not initialize
        c2 = eng.connect()
        dbapi_conn_two = c2.connection.dbapi_connection

        is_not(dbapi_conn_one, dbapi_conn_two)

        eq_(
            m1.mock_calls,
            [
                call.on_connect(dbapi_conn_one),
                call.initialize(mock.ANY),
                call.on_connect(dbapi_conn_two),
            ],
        )

        c.close()
        c2.close()

    @testing.only_on("sqlite+pysqlite")
    def test_initialize_connect_race(self):
        """test for :ticket:`6337` fixing the regression in :ticket:`5497`,
        dialect init is mutexed"""

        m1 = []
        cls_ = testing.db.dialect.__class__

        class SomeDialect(cls_):
            def initialize(self, connection):
                super(SomeDialect, self).initialize(connection)
                m1.append("initialize")

            def on_connect(self):
                oc = super(SomeDialect, self).on_connect()

                def my_on_connect(conn):
                    if oc:
                        oc(conn)
                    m1.append("on_connect")

                return my_on_connect

        u1 = Mock(
            username=None,
            password=None,
            host=None,
            port=None,
            query={},
            database=None,
            _instantiate_plugins=lambda kw: (u1, [], kw),
            _get_entrypoint=Mock(
                return_value=Mock(get_dialect_cls=lambda u: SomeDialect)
            ),
        )

        for _ in range(5):
            m1[:] = []
            eng = create_engine(
                u1,
                poolclass=NullPool,
                connect_args={"check_same_thread": False},
            )
            errors = []

            # bind this iteration's engine / error list explicitly
            def go(eng=eng, errors=errors):
                try:
                    with eng.connect() as c:
                        c.execute(text("select 1"))
                except Exception as err:
                    # an exception in a worker thread would otherwise be
                    # lost; record it so the main thread can fail on it
                    errors.append(err)

            threads = [threading.Thread(target=go) for _ in range(10)]
            for t in threads:
                t.start()
            for t in threads:
                t.join()

            eq_(errors, [])
            # initialize runs once, right after the first on_connect
            eq_(m1, ["on_connect", "initialize"] + ["on_connect"] * 9)
3320
3321
class DialectEventTest(fixtures.TestBase):
    """Tests for the dialect-level events ``do_execute``,
    ``do_executemany``, ``do_execute_no_params`` and ``do_connect``."""

    @contextmanager
    def _run_test(self, retval):
        """Yield ``(conn, m1)`` for an engine whose execute hooks are mocked.

        ``m1.do_execute`` / ``m1.do_executemany`` /
        ``m1.do_execute_no_params`` are registered as event listeners
        returning ``retval``. ``m1.real_*`` replace the dialect's own
        methods, so tests can tell whether the dialect was still called.
        ``conn`` comes from ``engine.begin()``.
        """
        m1 = Mock()

        m1.do_execute.return_value = retval
        m1.do_executemany.return_value = retval
        m1.do_execute_no_params.return_value = retval
        e = engines.testing_engine(options={"_initialize": False})

        event.listen(e, "do_execute", m1.do_execute)
        event.listen(e, "do_executemany", m1.do_executemany)
        event.listen(e, "do_execute_no_params", m1.do_execute_no_params)

        e.dialect.do_execute = m1.real_do_execute
        e.dialect.do_executemany = m1.real_do_executemany
        e.dialect.do_execute_no_params = m1.real_do_execute_no_params

        def mock_the_cursor(cursor, *arg):
            # the last positional argument is the execution context; give
            # it a stubbed get_result_proxy returning a Mock tied to it
            arg[-1].get_result_proxy = Mock(return_value=Mock(context=arg[-1]))
            return retval

        m1.real_do_execute.side_effect = (
            m1.do_execute.side_effect
        ) = mock_the_cursor
        m1.real_do_executemany.side_effect = (
            m1.do_executemany.side_effect
        ) = mock_the_cursor
        m1.real_do_execute_no_params.side_effect = (
            m1.do_execute_no_params.side_effect
        ) = mock_the_cursor

        with e.begin() as conn:
            yield conn, m1

    def _assert(self, retval, m1, m2, mock_calls):
        """Check listener ``m1`` received ``mock_calls``.

        If ``retval`` is truthy, the real dialect method ``m2`` must not
        have been called. Otherwise ``m2`` must have received the same
        calls.
        """
        eq_(m1.mock_calls, mock_calls)
        if retval:
            eq_(m2.mock_calls, [])
        else:
            eq_(m2.mock_calls, mock_calls)

    def _test_do_execute(self, retval):
        with self._run_test(retval) as (conn, m1):
            result = conn.exec_driver_sql(
                "insert into table foo", {"foo": "bar"}
            )
        self._assert(
            retval,
            m1.do_execute,
            m1.real_do_execute,
            [
                call(
                    result.context.cursor,
                    "insert into table foo",
                    {"foo": "bar"},
                    result.context,
                )
            ],
        )

    def _test_do_executemany(self, retval):
        with self._run_test(retval) as (conn, m1):
            result = conn.exec_driver_sql(
                "insert into table foo", [{"foo": "bar"}, {"foo": "bar"}]
            )
        self._assert(
            retval,
            m1.do_executemany,
            m1.real_do_executemany,
            [
                call(
                    result.context.cursor,
                    "insert into table foo",
                    [{"foo": "bar"}, {"foo": "bar"}],
                    result.context,
                )
            ],
        )

    def _test_do_execute_no_params(self, retval):
        with self._run_test(retval) as (conn, m1):
            result = conn.execution_options(
                no_parameters=True
            ).exec_driver_sql("insert into table foo")
        self._assert(
            retval,
            m1.do_execute_no_params,
            m1.real_do_execute_no_params,
            [
                call(
                    result.context.cursor,
                    "insert into table foo",
                    result.context,
                )
            ],
        )

    def _test_cursor_execute(self, retval):
        """Same as _test_do_execute, but drives the lower-level
        Connection._cursor_execute() with a hand-built execution
        context."""
        with self._run_test(retval) as (conn, m1):
            dialect = conn.dialect

            stmt = "insert into table foo"
            params = {"foo": "bar"}
            ctx = dialect.execution_ctx_cls._init_statement(
                dialect,
                conn,
                conn.connection,
                {},
                stmt,
                [params],
            )

            conn._cursor_execute(ctx.cursor, stmt, params, ctx)

        self._assert(
            retval,
            m1.do_execute,
            m1.real_do_execute,
            [call(ctx.cursor, "insert into table foo", {"foo": "bar"}, ctx)],
        )

    def test_do_execute_w_replace(self):
        self._test_do_execute(True)

    def test_do_execute_wo_replace(self):
        self._test_do_execute(False)

    def test_do_executemany_w_replace(self):
        self._test_do_executemany(True)

    def test_do_executemany_wo_replace(self):
        self._test_do_executemany(False)

    def test_do_execute_no_params_w_replace(self):
        self._test_do_execute_no_params(True)

    def test_do_execute_no_params_wo_replace(self):
        self._test_do_execute_no_params(False)

    def test_cursor_execute_w_replace(self):
        self._test_cursor_execute(True)

    def test_cursor_execute_wo_replace(self):
        self._test_cursor_execute(False)

    def test_connect_replace_params(self):
        """do_connect listeners may rewrite cargs / cparams in place and
        populate the connection record's info dict."""
        e = engines.testing_engine(options={"_initialize": False})

        @event.listens_for(e, "do_connect")
        def evt(dialect, conn_rec, cargs, cparams):
            cargs[:] = ["foo", "hoho"]
            cparams.clear()
            cparams["bar"] = "bat"
            conn_rec.info["boom"] = "bap"

        m1 = Mock()
        e.dialect.connect = m1.real_connect

        with e.connect() as conn:
            eq_(m1.mock_calls, [call.real_connect("foo", "hoho", bar="bat")])
            eq_(conn.info["boom"], "bap")

    def test_connect_do_connect(self):
        """A do_connect listener returning a value supplies the DBAPI
        connection itself; earlier listeners' argument edits are seen."""
        e = engines.testing_engine(options={"_initialize": False})

        m1 = Mock()

        @event.listens_for(e, "do_connect")
        def evt1(dialect, conn_rec, cargs, cparams):
            cargs[:] = ["foo", "hoho"]
            cparams.clear()
            cparams["bar"] = "bat"
            conn_rec.info["boom"] = "one"

        @event.listens_for(e, "do_connect")
        def evt2(dialect, conn_rec, cargs, cparams):
            conn_rec.info["bap"] = "two"
            return m1.our_connect(cargs, cparams)

        with e.connect() as conn:
            # called with args
            eq_(
                m1.mock_calls,
                [call.our_connect(["foo", "hoho"], {"bar": "bat"})],
            )

            eq_(conn.info["boom"], "one")
            eq_(conn.info["bap"], "two")

            # returned our mock connection
            is_(conn.connection.dbapi_connection, m1.our_connect())

    def test_connect_do_connect_info_there_after_recycle(self):
        # test that info is maintained after the do_connect()
        # event for a soft invalidation.

        e = engines.testing_engine(options={"_initialize": False})

        @event.listens_for(e, "do_connect")
        def evt1(dialect, conn_rec, cargs, cparams):
            conn_rec.info["boom"] = "one"

        conn = e.connect()
        eq_(conn.info["boom"], "one")

        conn.connection.invalidate(soft=True)
        conn.close()

        # the context manager makes sure this second connection is
        # released as well
        with e.connect() as conn:
            eq_(conn.info["boom"], "one")

    def test_connect_do_connect_info_there_after_invalidate(self):
        # test that info is maintained after the do_connect()
        # event for a hard invalidation.

        e = engines.testing_engine(options={"_initialize": False})

        @event.listens_for(e, "do_connect")
        def evt1(dialect, conn_rec, cargs, cparams):
            # info must start empty on every fresh do_connect()
            assert not conn_rec.info
            conn_rec.info["boom"] = "one"

        conn = e.connect()
        eq_(conn.info["boom"], "one")

        conn.connection.invalidate()

        # the context manager makes sure the replacement connection is
        # released
        with e.connect() as conn:
            eq_(conn.info["boom"], "one")
3550
3551
class FutureExecuteTest(fixtures.FutureEngineMixin, fixtures.TablesTest):
    """Execution tests run against the ``sqlalchemy.future`` engine and
    Connection (via ``FutureEngineMixin``)."""

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        # primary key values are always supplied explicitly
        Table(
            "users",
            metadata,
            Column("user_id", INT, primary_key=True, autoincrement=False),
            Column("user_name", VARCHAR(20)),
            test_needs_acid=True,
        )

        # same shape, but the primary key requires autoincrement support
        autoinc_pk = Column(
            "user_id", INT, primary_key=True, test_needs_autoincrement=True
        )
        Table(
            "users_autoinc",
            metadata,
            autoinc_pk,
            Column("user_name", VARCHAR(20)),
            test_needs_acid=True,
        )

    def test_non_dict_mapping(self, connection):
        """execute() accepts any Mapping implementation as its parameter
        set, not only a dict."""

        class NotADict(collections_abc.Mapping):
            # minimal read-only Mapping that is deliberately not a dict
            def __init__(self, _data):
                self._data = _data

            def __getitem__(self, key):
                return self._data[key]

            def __len__(self):
                return len(self._data)

            def __iter__(self):
                return iter(self._data)

            def keys(self):
                return self._data.keys()

        params = NotADict({"a": 10, "b": 15})

        # sanity check: the Mapping protocol alone is enough to build a dict
        eq_(dict(params), {"a": 10, "b": 15})

        stmt = select(
            bindparam("a", type_=Integer), bindparam("b", type_=Integer)
        )
        eq_(connection.execute(stmt, params).first(), (10, 15))

    def test_row_works_as_mapping(self, connection):
        """A Row's ``_mapping`` view can be passed straight back to
        execute() as the parameter dictionary."""

        source_row = connection.execute(
            select(literal(10).label("a"), literal(15).label("b"))
        ).first()
        eq_(source_row, (10, 15))
        eq_(source_row._mapping, {"a": 10, "b": 15})

        echo_stmt = select(
            bindparam("a", type_=Integer).label("a"),
            bindparam("b", type_=Integer).label("b"),
        )
        echoed_row = connection.execute(
            echo_stmt, source_row._mapping
        ).first()
        eq_(echoed_row, (10, 15))
        eq_(echoed_row._mapping, {"a": 10, "b": 15})

    @testing.combinations(
        ({}, {}, {}),
        ({"a": "b"}, {}, {"a": "b"}),
        ({"a": "b", "d": "e"}, {"a": "c"}, {"a": "c", "d": "e"}),
        argnames="conn_opts, exec_opts, expected",
    )
    def test_execution_opts_per_invoke(
        self, connection, conn_opts, exec_opts, expected
    ):
        # the execution context sees connection-level options merged with
        # per-invocation options, the latter taking precedence
        seen_opts = []

        @event.listens_for(connection, "before_cursor_execute")
        def before_cursor_execute(
            conn, cursor, statement, parameters, context, executemany
        ):
            seen_opts.append(context.execution_options)

        target = (
            connection.execution_options(**conn_opts)
            if conn_opts
            else connection
        )
        extra = {"execution_options": exec_opts} if exec_opts else {}
        target.execute(select(1), **extra)

        eq_(seen_opts, [expected])

    @testing.combinations(
        ({}, {}, {}, {}),
        ({}, {"a": "b"}, {}, {"a": "b"}),
        ({}, {"a": "b", "d": "e"}, {"a": "c"}, {"a": "c", "d": "e"}),
        (
            {"q": "z", "p": "r"},
            {"a": "b", "p": "x", "d": "e"},
            {"a": "c"},
            {"q": "z", "p": "x", "a": "c", "d": "e"},
        ),
        argnames="stmt_opts, conn_opts, exec_opts, expected",
    )
    def test_execution_opts_per_invoke_execute_events(
        self, connection, stmt_opts, conn_opts, exec_opts, expected
    ):
        # before_execute / after_execute both receive the merged statement,
        # connection and per-invocation options
        seen = []

        @event.listens_for(connection, "before_execute")
        def before_execute(
            conn, clauseelement, multiparams, params, execution_options
        ):
            seen.append(("before", execution_options))

        @event.listens_for(connection, "after_execute")
        def after_execute(
            conn,
            clauseelement,
            multiparams,
            params,
            execution_options,
            result,
        ):
            seen.append(("after", execution_options))

        stmt = select(1)
        if stmt_opts:
            stmt = stmt.execution_options(**stmt_opts)

        target = (
            connection.execution_options(**conn_opts)
            if conn_opts
            else connection
        )
        extra = {"execution_options": exec_opts} if exec_opts else {}
        target.execute(stmt, **extra)

        eq_(seen, [("before", expected), ("after", expected)])

    def test_no_branching(self, connection):
        # Connection.connect() emits a deprecation warning and, on a future
        # Connection, refuses to create a "branch"
        legacy_warning = (
            r"The Connection.connect\(\) method is considered legacy"
        )
        with testing.expect_deprecated(legacy_warning):
            assert_raises_message(
                NotImplementedError,
                "sqlalchemy.future.Connection does not support "
                "'branching' of new connections.",
                connection.connect,
            )
3712
3713
class SetInputSizesTest(fixtures.TablesTest):
    """Checks which ``(key, dbtype, sqltype)`` entries reach a dialect's
    ``do_set_input_sizes`` hook, and how the ``do_setinputsizes`` event can
    rewrite them."""

    __backend__ = True

    __requires__ = ("independent_connections",)

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "users",
            metadata,
            Column("user_id", INT, primary_key=True, autoincrement=False),
            Column("user_name", VARCHAR(20)),
        )

    @staticmethod
    def _expected_set_input_sizes(*entries):
        """Build the expected ``canary.mock_calls`` for exactly one
        ``do_set_input_sizes`` invocation.

        Each entry is ``(param_name, dbtype, sqltype_class)``.  The cursor
        and context arguments are matched with ``mock.ANY``, and the SQL
        type is compared by type affinity.
        """
        return [
            call.do_set_input_sizes(
                mock.ANY,
                [
                    (name, dbtype, testing.eq_type_affinity(type_cls))
                    for name, dbtype, type_cls in entries
                ],
                mock.ANY,
            )
        ]

    @testing.fixture
    def input_sizes_fixture(self, testing_engine):
        canary = mock.Mock()

        engine = testing_engine()
        engine.connect().close()

        def do_set_input_sizes(cursor, list_of_tuples, context):
            if not engine.dialect.positional:
                # for named-parameter dialects, sort entries by parameter
                # name so the asserted order is deterministic; with the
                # dict-based parameters used here this is likely only
                # relevant on older Pythons
                list_of_tuples = sorted(
                    list_of_tuples, key=lambda entry: entry[0]
                )
            canary.do_set_input_sizes(cursor, list_of_tuples, context)

        def pre_exec(self):
            # replacement execution-context pre_exec hook: clear the
            # translate / include / exclude setinputsizes settings
            self.translate_set_input_sizes = None
            self.include_set_input_sizes = None
            self.exclude_set_input_sizes = None

        # The dialect's do_set_input_sizes is swapped out wholesale for the
        # recorder above, so the arguments it would have received can be
        # inspected.  The table uses plain datatypes so that the DBAPI
        # does not actually need setinputsizes() in order to work.
        dialect = engine.dialect
        with mock.patch.object(
            dialect, "use_setinputsizes", True
        ), mock.patch.object(
            dialect, "do_set_input_sizes", do_set_input_sizes
        ), mock.patch.object(
            dialect.execution_ctx_cls, "pre_exec", pre_exec
        ):
            yield engine, canary

    def test_set_input_sizes_no_event(self, input_sizes_fixture):
        engine, canary = input_sizes_fixture
        users = self.tables.users

        with engine.begin() as conn:
            conn.execute(
                users.insert(),
                [
                    {"user_id": 1, "user_name": "n1"},
                    {"user_id": 2, "user_name": "n2"},
                ],
            )

        eq_(
            canary.mock_calls,
            self._expected_set_input_sizes(
                ("user_id", mock.ANY, Integer),
                ("user_name", mock.ANY, String),
            ),
        )

    def test_set_input_sizes_expanding_param(self, input_sizes_fixture):
        engine, canary = input_sizes_fixture
        users = self.tables.users

        stmt = select(users).where(users.c.user_name.in_(["x", "y", "z"]))
        with engine.connect() as conn:
            conn.execute(stmt)

        # one entry per element of the expanded IN list
        eq_(
            canary.mock_calls,
            self._expected_set_input_sizes(
                ("user_name_1_1", mock.ANY, String),
                ("user_name_1_2", mock.ANY, String),
                ("user_name_1_3", mock.ANY, String),
            ),
        )

    @testing.requires.tuple_in
    def test_set_input_sizes_expanding_tuple_param(self, input_sizes_fixture):
        engine, canary = input_sizes_fixture
        users = self.tables.users

        from sqlalchemy import tuple_

        stmt = select(users).where(
            tuple_(users.c.user_id, users.c.user_name).in_(
                [(1, "x"), (2, "y")]
            )
        )
        with engine.connect() as conn:
            conn.execute(stmt)

        # one entry per element of each expanded tuple
        eq_(
            canary.mock_calls,
            self._expected_set_input_sizes(
                ("param_1_1_1", mock.ANY, Integer),
                ("param_1_1_2", mock.ANY, String),
                ("param_1_2_1", mock.ANY, Integer),
                ("param_1_2_2", mock.ANY, String),
            ),
        )

    def test_set_input_sizes_event(self, input_sizes_fixture):
        engine, canary = input_sizes_fixture
        users = self.tables.users

        SPECIAL_STRING = mock.Mock()

        @event.listens_for(engine, "do_setinputsizes")
        def do_setinputsizes(
            inputsizes, cursor, statement, parameters, context
        ):
            # overwrite, in place, the entry for every String-affinity bind
            inputsizes.update(
                {
                    bind: (SPECIAL_STRING, None, 0)
                    for bind in inputsizes
                    if bind.type._type_affinity is String
                }
            )

        with engine.begin() as conn:
            conn.execute(
                users.insert(),
                [
                    {"user_id": 1, "user_name": "n1"},
                    {"user_id": 2, "user_name": "n2"},
                ],
            )

        eq_(
            canary.mock_calls,
            self._expected_set_input_sizes(
                ("user_id", mock.ANY, Integer),
                ("user_name", (SPECIAL_STRING, None, 0), String),
            ),
        )
3933
3934
class DialectDoesntSupportCachingTest(fixtures.TestBase):
    """Opt-in statement caching flag for dialects (:ticket:`6184`)."""

    __only_on__ = "sqlite+pysqlite"

    __requires__ = ("sqlite_memory",)

    @testing.fixture()
    def sqlite_no_cache_dialect(self, testing_engine):
        from sqlalchemy.dialects import registry
        from sqlalchemy.dialects.sqlite.pysqlite import SQLiteDialect_pysqlite
        from sqlalchemy.dialects.sqlite.base import SQLiteCompiler
        from sqlalchemy.sql import visitors

        class MyCompiler(SQLiteCompiler):
            def translate_select_structure(self, select_stmt, **kwargs):
                # statements already rewritten by this compiler pass through
                if getattr(select_stmt, "_mydialect_visit", None):
                    return select_stmt

                stmt = visitors.cloned_traverse(select_stmt, {}, {})
                if stmt._limit_clause is not None:
                    # swap in a fixed-name, literal-rendered bindparam whose
                    # value is hardcoded from this statement's limit; a
                    # cached compiled form would replay that value, so this
                    # compiler is not safe to cache
                    stmt._limit_clause = bindparam(
                        "limit", value=stmt._limit, literal_execute=True
                    )

                stmt._mydialect_visit = True
                return stmt

        class MyDialect(SQLiteDialect_pysqlite):
            statement_compiler = MyCompiler

        def load_my_dialect(name):
            # registry.load() stand-in: any requested name yields MyDialect
            return MyDialect

        with mock.patch.object(registry, "load", load_my_dialect):
            yield testing_engine()

    @testing.fixture
    def data_fixture(self, sqlite_no_cache_dialect):
        table = Table("t1", MetaData(), Column("x", Integer))
        with sqlite_no_cache_dialect.begin() as conn:
            table.create(conn)
            conn.execute(
                table.insert(), [{"x": value} for value in range(1, 5)]
            )

        return table

    @staticmethod
    def _run_limited(eng, table, lim):
        """Run ``SELECT ... ORDER BY x LIMIT lim`` and return the result.

        The connection is released before the caller fetches any rows.
        """
        with eng.connect() as conn:
            return conn.execute(
                select(table).order_by(table.c.x).limit(lim)
            )

    def test_no_cache(self, sqlite_no_cache_dialect, data_fixture):
        eng = sqlite_no_cache_dialect

        two = self._run_limited(eng, data_fixture, 2)
        three = self._run_limited(eng, data_fixture, 3)

        eq_(two.all(), [(1,), (2,)])
        eq_(three.all(), [(1,), (2,), (3,)])

    def test_it_caches(self, sqlite_no_cache_dialect, data_fixture):
        eng = sqlite_no_cache_dialect

        # opt the fixture-local dialect class in to statement caching, and
        # drop the value stored in the dialect instance's __dict__
        # (presumably so it is recomputed from the class flag)
        eng.dialect.__class__.supports_statement_cache = True
        del eng.dialect.__dict__["_supports_statement_cache"]

        two = self._run_limited(eng, data_fixture, 2)
        three = self._run_limited(eng, data_fixture, 3)

        eq_(two.all(), [(1,), (2,)])

        # wrong answer: the LIMIT 3 query comes back with only two rows,
        # showing the hardcoded limit from the first compilation was reused
        eq_(
            three.all(),
            [
                (1,),
                (2,),
            ],
        )
4028