1# coding: utf-8 2 3from contextlib import contextmanager 4import re 5import threading 6import weakref 7 8import sqlalchemy as tsa 9from sqlalchemy import bindparam 10from sqlalchemy import create_engine 11from sqlalchemy import create_mock_engine 12from sqlalchemy import event 13from sqlalchemy import func 14from sqlalchemy import inspect 15from sqlalchemy import INT 16from sqlalchemy import Integer 17from sqlalchemy import LargeBinary 18from sqlalchemy import MetaData 19from sqlalchemy import select 20from sqlalchemy import Sequence 21from sqlalchemy import String 22from sqlalchemy import testing 23from sqlalchemy import text 24from sqlalchemy import TypeDecorator 25from sqlalchemy import util 26from sqlalchemy import VARCHAR 27from sqlalchemy.engine import default 28from sqlalchemy.engine.base import Connection 29from sqlalchemy.engine.base import Engine 30from sqlalchemy.pool import NullPool 31from sqlalchemy.pool import QueuePool 32from sqlalchemy.sql import column 33from sqlalchemy.sql import literal 34from sqlalchemy.sql.elements import literal_column 35from sqlalchemy.testing import assert_raises 36from sqlalchemy.testing import assert_raises_message 37from sqlalchemy.testing import config 38from sqlalchemy.testing import engines 39from sqlalchemy.testing import eq_ 40from sqlalchemy.testing import expect_raises_message 41from sqlalchemy.testing import expect_warnings 42from sqlalchemy.testing import fixtures 43from sqlalchemy.testing import is_ 44from sqlalchemy.testing import is_false 45from sqlalchemy.testing import is_not 46from sqlalchemy.testing import is_true 47from sqlalchemy.testing import mock 48from sqlalchemy.testing.assertions import expect_deprecated 49from sqlalchemy.testing.assertsql import CompiledSQL 50from sqlalchemy.testing.mock import call 51from sqlalchemy.testing.mock import Mock 52from sqlalchemy.testing.mock import patch 53from sqlalchemy.testing.schema import Column 54from sqlalchemy.testing.schema import Table 55from sqlalchemy.testing.util import gc_collect 56from sqlalchemy.testing.util import picklers 57from sqlalchemy.util import collections_abc 58 59 60class SomeException(Exception): 61 pass 62 63 64class Foo(object): 65 def __str__(self): 66 return "foo" 67 68 def __unicode__(self): 69 return util.u("fóó") 70 71 72class ExecuteTest(fixtures.TablesTest): 73 __backend__ = True 74 75 @classmethod 76 def define_tables(cls, metadata): 77 Table( 78 "users", 79 metadata, 80 Column("user_id", INT, primary_key=True, autoincrement=False), 81 Column("user_name", VARCHAR(20)), 82 ) 83 Table( 84 "users_autoinc", 85 metadata, 86 Column( 87 "user_id", INT, primary_key=True, test_needs_autoincrement=True 88 ), 89 Column("user_name", VARCHAR(20)), 90 ) 91 92 def test_no_params_option(self): 93 stmt = ( 94 "SELECT '%'" 95 + testing.db.dialect.statement_compiler( 96 testing.db.dialect, None 97 ).default_from() 98 ) 99 100 with testing.db.connect() as conn: 101 result = ( 102 conn.execution_options(no_parameters=True) 103 .exec_driver_sql(stmt) 104 .scalar() 105 ) 106 eq_(result, "%") 107 108 def test_raw_positional_invalid(self, connection): 109 assert_raises_message( 110 tsa.exc.ArgumentError, 111 "List argument must consist only of tuples or dictionaries", 112 connection.exec_driver_sql, 113 "insert into users (user_id, user_name) " "values (?, ?)", 114 [2, "fred"], 115 ) 116 117 assert_raises_message( 118 tsa.exc.ArgumentError, 119 "List argument must consist only of tuples or dictionaries", 120 connection.exec_driver_sql, 121 "insert into users (user_id, user_name) " "values (?, ?)", 122 [[3, "ed"], [4, "horse"]], 123 ) 124 125 def test_raw_named_invalid(self, connection): 126 # this is awkward b.c. this is just testing if regular Python 127 # is raising TypeError if they happened to send arguments that 128 # look like the legacy ones which also happen to conflict with 129 # the positional signature for the method. some combinations 130 # can get through and fail differently 131 assert_raises( 132 TypeError, 133 connection.exec_driver_sql, 134 "insert into users (user_id, user_name) " 135 "values (%(id)s, %(name)s)", 136 {"id": 2, "name": "ed"}, 137 {"id": 3, "name": "horse"}, 138 {"id": 4, "name": "horse"}, 139 ) 140 assert_raises( 141 TypeError, 142 connection.exec_driver_sql, 143 "insert into users (user_id, user_name) " 144 "values (%(id)s, %(name)s)", 145 id=4, 146 name="sally", 147 ) 148 149 @testing.requires.qmark_paramstyle 150 def test_raw_qmark(self, connection): 151 conn = connection 152 conn.exec_driver_sql( 153 "insert into users (user_id, user_name) " "values (?, ?)", 154 (1, "jack"), 155 ) 156 conn.exec_driver_sql( 157 "insert into users (user_id, user_name) " "values (?, ?)", 158 (2, "fred"), 159 ) 160 conn.exec_driver_sql( 161 "insert into users (user_id, user_name) " "values (?, ?)", 162 [(3, "ed"), (4, "horse")], 163 ) 164 conn.exec_driver_sql( 165 "insert into users (user_id, user_name) " "values (?, ?)", 166 [(5, "barney"), (6, "donkey")], 167 ) 168 conn.exec_driver_sql( 169 "insert into users (user_id, user_name) " "values (?, ?)", 170 (7, "sally"), 171 ) 172 res = conn.exec_driver_sql("select * from users order by user_id") 173 assert res.fetchall() == [ 174 (1, "jack"), 175 (2, "fred"), 176 (3, "ed"), 177 (4, "horse"), 178 (5, "barney"), 179 (6, "donkey"), 180 (7, "sally"), 181 ] 182 183 res = conn.exec_driver_sql( 184 "select * from users where user_name=?", ("jack",) 185 ) 186 assert res.fetchall() == [(1, "jack")] 187 188 @testing.requires.format_paramstyle 189 def test_raw_sprintf(self, connection): 190 conn = connection 191 conn.exec_driver_sql( 192 "insert into users (user_id, user_name) " "values (%s, %s)", 193 (1, "jack"), 194 ) 195 conn.exec_driver_sql( 196 "insert into users (user_id, user_name) " "values (%s, %s)", 197 [(2, "ed"), (3, "horse")], 198 ) 199 conn.exec_driver_sql( 200 "insert into users (user_id, user_name) " "values (%s, %s)", 201 (4, "sally"), 202 ) 203 conn.exec_driver_sql("insert into users (user_id) values (%s)", (5,)) 204 res = conn.exec_driver_sql("select * from users order by user_id") 205 assert res.fetchall() == [ 206 (1, "jack"), 207 (2, "ed"), 208 (3, "horse"), 209 (4, "sally"), 210 (5, None), 211 ] 212 213 res = conn.exec_driver_sql( 214 "select * from users where user_name=%s", ("jack",) 215 ) 216 assert res.fetchall() == [(1, "jack")] 217 218 @testing.requires.pyformat_paramstyle 219 def test_raw_python(self, connection): 220 conn = connection 221 conn.exec_driver_sql( 222 "insert into users (user_id, user_name) " 223 "values (%(id)s, %(name)s)", 224 {"id": 1, "name": "jack"}, 225 ) 226 conn.exec_driver_sql( 227 "insert into users (user_id, user_name) " 228 "values (%(id)s, %(name)s)", 229 [{"id": 2, "name": "ed"}, {"id": 3, "name": "horse"}], 230 ) 231 conn.exec_driver_sql( 232 "insert into users (user_id, user_name) " 233 "values (%(id)s, %(name)s)", 234 dict(id=4, name="sally"), 235 ) 236 res = conn.exec_driver_sql("select * from users order by user_id") 237 assert res.fetchall() == [ 238 (1, "jack"), 239 (2, "ed"), 240 (3, "horse"), 241 (4, "sally"), 242 ] 243 244 @testing.requires.named_paramstyle 245 def test_raw_named(self, connection): 246 conn = connection 247 conn.exec_driver_sql( 248 "insert into users (user_id, user_name) " "values (:id, :name)", 249 {"id": 1, "name": "jack"}, 250 ) 251 conn.exec_driver_sql( 252 "insert into users (user_id, user_name) " "values (:id, :name)", 253 [{"id": 2, "name": "ed"}, {"id": 3, "name": "horse"}], 254 ) 255 conn.exec_driver_sql( 256 "insert into users (user_id, user_name) " "values (:id, :name)", 257 {"id": 4, "name": "sally"}, 258 ) 259 res = conn.exec_driver_sql("select * from users order by user_id") 260 assert res.fetchall() == [ 261 (1, "jack"), 262 (2, "ed"), 263 (3, "horse"), 264 (4, "sally"), 265 ] 266 267 def test_non_dict_mapping(self, connection): 268 """ensure arbitrary Mapping works for execute()""" 269 270 class NotADict(collections_abc.Mapping): 271 def __init__(self, _data): 272 self._data = _data 273 274 def __iter__(self): 275 return iter(self._data) 276 277 def __len__(self): 278 return len(self._data) 279 280 def __getitem__(self, key): 281 return self._data[key] 282 283 def keys(self): 284 return self._data.keys() 285 286 nd = NotADict({"a": 10, "b": 15}) 287 eq_(dict(nd), {"a": 10, "b": 15}) 288 289 result = connection.execute( 290 select( 291 bindparam("a", type_=Integer), bindparam("b", type_=Integer) 292 ), 293 nd, 294 ) 295 eq_(result.first(), (10, 15)) 296 297 def test_row_works_as_mapping(self, connection): 298 """ensure the RowMapping object works as a parameter dictionary for 299 execute.""" 300 301 result = connection.execute( 302 select(literal(10).label("a"), literal(15).label("b")) 303 ) 304 row = result.first() 305 eq_(row, (10, 15)) 306 eq_(row._mapping, {"a": 10, "b": 15}) 307 308 result = connection.execute( 309 select( 310 bindparam("a", type_=Integer).label("a"), 311 bindparam("b", type_=Integer).label("b"), 312 ), 313 row._mapping, 314 ) 315 row = result.first() 316 eq_(row, (10, 15)) 317 eq_(row._mapping, {"a": 10, "b": 15}) 318 319 def test_dialect_has_table_assertion(self): 320 with expect_raises_message( 321 tsa.exc.ArgumentError, 322 r"The argument passed to Dialect.has_table\(\) should be a", 323 ): 324 testing.db.dialect.has_table(testing.db, "some_table") 325 326 def test_exception_wrapping_dbapi(self): 327 with testing.db.connect() as conn: 328 # engine does not have exec_driver_sql 329 assert_raises_message( 330 tsa.exc.DBAPIError, 331 r"not_a_valid_statement", 332 conn.exec_driver_sql, 333 "not_a_valid_statement", 334 ) 335 336 @testing.requires.sqlite 337 def test_exception_wrapping_non_dbapi_error(self): 338 e = create_engine("sqlite://") 339 e.dialect.is_disconnect = is_disconnect = Mock() 340 341 with e.connect() as c: 342 c.connection.cursor = Mock( 343 return_value=Mock( 344 execute=Mock( 345 side_effect=TypeError("I'm not a DBAPI error") 346 ) 347 ) 348 ) 349 350 assert_raises_message( 351 TypeError, 352 "I'm not a DBAPI error", 353 c.exec_driver_sql, 354 "select ", 355 ) 356 eq_(is_disconnect.call_count, 0) 357 358 def test_exception_wrapping_non_standard_dbapi_error(self): 359 class DBAPIError(Exception): 360 pass 361 362 class OperationalError(DBAPIError): 363 pass 364 365 class NonStandardException(OperationalError): 366 pass 367 368 # TODO: this test is assuming too much of arbitrary dialects and would 369 # be better suited tested against a single mock dialect that does not 370 # have any special behaviors 371 with patch.object( 372 testing.db.dialect, "dbapi", Mock(Error=DBAPIError) 373 ), patch.object( 374 testing.db.dialect, "is_disconnect", lambda *arg: False 375 ), patch.object( 376 testing.db.dialect, 377 "do_execute", 378 Mock(side_effect=NonStandardException), 379 ), patch.object( 380 testing.db.dialect.execution_ctx_cls, 381 "handle_dbapi_exception", 382 Mock(), 383 ): 384 with testing.db.connect() as conn: 385 assert_raises( 386 tsa.exc.OperationalError, conn.exec_driver_sql, "select 1" 387 ) 388 389 def test_exception_wrapping_non_dbapi_statement(self): 390 class MyType(TypeDecorator): 391 impl = Integer 392 cache_ok = True 393 394 def process_bind_param(self, value, dialect): 395 raise SomeException("nope") 396 397 def _go(conn): 398 assert_raises_message( 399 tsa.exc.StatementError, 400 r"\(.*.SomeException\) " r"nope\n\[SQL\: u?SELECT 1 ", 401 conn.execute, 402 select(1).where(column("foo") == literal("bar", MyType())), 403 ) 404 405 with testing.db.connect() as conn: 406 _go(conn) 407 408 def test_not_an_executable(self): 409 for obj in ( 410 Table("foo", MetaData(), Column("x", Integer)), 411 Column("x", Integer), 412 tsa.and_(True), 413 tsa.and_(True).compile(), 414 column("foo"), 415 column("foo").compile(), 416 select(1).cte(), 417 # select(1).subquery(), 418 MetaData(), 419 Integer(), 420 tsa.Index(name="foo"), 421 tsa.UniqueConstraint("x"), 422 ): 423 with testing.db.connect() as conn: 424 assert_raises_message( 425 tsa.exc.ObjectNotExecutableError, 426 "Not an executable object", 427 conn.execute, 428 obj, 429 ) 430 431 def test_subquery_exec_warning(self): 432 for obj in (select(1).alias(), select(1).subquery()): 433 with testing.db.connect() as conn: 434 with expect_deprecated( 435 "Executing a subquery object is deprecated and will " 436 "raise ObjectNotExecutableError" 437 ): 438 eq_(conn.execute(obj).scalar(), 1) 439 440 def test_stmt_exception_bytestring_raised(self): 441 name = util.u("méil") 442 users = self.tables.users 443 with testing.db.connect() as conn: 444 assert_raises_message( 445 tsa.exc.StatementError, 446 util.u( 447 "A value is required for bind parameter 'uname'\n" 448 r".*SELECT users.user_name AS .m\xe9il." 449 ) 450 if util.py2k 451 else util.u( 452 "A value is required for bind parameter 'uname'\n" 453 ".*SELECT users.user_name AS .méil." 454 ), 455 conn.execute, 456 select(users.c.user_name.label(name)).where( 457 users.c.user_name == bindparam("uname") 458 ), 459 {"uname_incorrect": "foo"}, 460 ) 461 462 def test_stmt_exception_bytestring_utf8(self): 463 # uncommon case for Py3K, bytestring object passed 464 # as the error message 465 message = util.u("some message méil").encode("utf-8") 466 467 err = tsa.exc.SQLAlchemyError(message) 468 if util.py2k: 469 # string passes it through 470 eq_(str(err), message) 471 472 # unicode accessor decodes to utf-8 473 eq_(unicode(err), util.u("some message méil")) # noqa F821 474 else: 475 eq_(str(err), util.u("some message méil")) 476 477 def test_stmt_exception_bytestring_latin1(self): 478 # uncommon case for Py3K, bytestring object passed 479 # as the error message 480 message = util.u("some message méil").encode("latin-1") 481 482 err = tsa.exc.SQLAlchemyError(message) 483 if util.py2k: 484 # string passes it through 485 eq_(str(err), message) 486 487 # unicode accessor decodes to utf-8 488 eq_(unicode(err), util.u("some message m\\xe9il")) # noqa F821 489 else: 490 eq_(str(err), util.u("some message m\\xe9il")) 491 492 def test_stmt_exception_unicode_hook_unicode(self): 493 # uncommon case for Py2K, Unicode object passed 494 # as the error message 495 message = util.u("some message méil") 496 497 err = tsa.exc.SQLAlchemyError(message) 498 if util.py2k: 499 eq_(unicode(err), util.u("some message méil")) # noqa F821 500 else: 501 eq_(str(err), util.u("some message méil")) 502 503 def test_stmt_exception_object_arg(self): 504 err = tsa.exc.SQLAlchemyError(Foo()) 505 eq_(str(err), "foo") 506 507 if util.py2k: 508 eq_(unicode(err), util.u("fóó")) # noqa F821 509 510 def test_stmt_exception_str_multi_args(self): 511 err = tsa.exc.SQLAlchemyError("some message", 206) 512 eq_(str(err), "('some message', 206)") 513 514 def test_stmt_exception_str_multi_args_bytestring(self): 515 message = util.u("some message méil").encode("utf-8") 516 517 err = tsa.exc.SQLAlchemyError(message, 206) 518 eq_(str(err), str((message, 206))) 519 520 def test_stmt_exception_str_multi_args_unicode(self): 521 message = util.u("some message méil") 522 523 err = tsa.exc.SQLAlchemyError(message, 206) 524 eq_(str(err), str((message, 206))) 525 526 def test_stmt_exception_pickleable_no_dbapi(self): 527 self._test_stmt_exception_pickleable(Exception("hello world")) 528 529 @testing.crashes( 530 "postgresql+psycopg2", 531 "Older versions don't support cursor pickling, newer ones do", 532 ) 533 @testing.fails_on( 534 "mysql+oursql", 535 "Exception doesn't come back exactly the same from pickle", 536 ) 537 @testing.fails_on( 538 "mysql+mysqlconnector", 539 "Exception doesn't come back exactly the same from pickle", 540 ) 541 @testing.fails_on( 542 "oracle+cx_oracle", 543 "cx_oracle exception seems to be having " "some issue with pickling", 544 ) 545 def test_stmt_exception_pickleable_plus_dbapi(self): 546 raw = testing.db.raw_connection() 547 the_orig = None 548 try: 549 try: 550 cursor = raw.cursor() 551 cursor.execute("SELECTINCORRECT") 552 except testing.db.dialect.dbapi.Error as orig: 553 # py3k has "orig" in local scope... 554 the_orig = orig 555 finally: 556 raw.close() 557 self._test_stmt_exception_pickleable(the_orig) 558 559 def _test_stmt_exception_pickleable(self, orig): 560 for sa_exc in ( 561 tsa.exc.StatementError( 562 "some error", 563 "select * from table", 564 {"foo": "bar"}, 565 orig, 566 False, 567 ), 568 tsa.exc.InterfaceError( 569 "select * from table", {"foo": "bar"}, orig, True 570 ), 571 tsa.exc.NoReferencedTableError("message", "tname"), 572 tsa.exc.NoReferencedColumnError("message", "tname", "cname"), 573 tsa.exc.CircularDependencyError( 574 "some message", [1, 2, 3], [(1, 2), (3, 4)] 575 ), 576 ): 577 for loads, dumps in picklers(): 578 repickled = loads(dumps(sa_exc)) 579 eq_(repickled.args[0], sa_exc.args[0]) 580 if isinstance(sa_exc, tsa.exc.StatementError): 581 eq_(repickled.params, {"foo": "bar"}) 582 eq_(repickled.statement, sa_exc.statement) 583 if hasattr(sa_exc, "connection_invalidated"): 584 eq_( 585 repickled.connection_invalidated, 586 sa_exc.connection_invalidated, 587 ) 588 eq_(repickled.orig.args[0], orig.args[0]) 589 590 def test_dont_wrap_mixin(self): 591 class MyException(Exception, tsa.exc.DontWrapMixin): 592 pass 593 594 class MyType(TypeDecorator): 595 impl = Integer 596 cache_ok = True 597 598 def process_bind_param(self, value, dialect): 599 raise MyException("nope") 600 601 def _go(conn): 602 assert_raises_message( 603 MyException, 604 "nope", 605 conn.execute, 606 select(1).where(column("foo") == literal("bar", MyType())), 607 ) 608 609 conn = testing.db.connect() 610 try: 611 _go(conn) 612 finally: 613 conn.close() 614 615 def test_empty_insert(self, connection): 616 """test that execute() interprets [] as a list with no params""" 617 users_autoinc = self.tables.users_autoinc 618 619 connection.execute( 620 users_autoinc.insert().values(user_name=bindparam("name", None)), 621 [], 622 ) 623 eq_(connection.execute(users_autoinc.select()).fetchall(), [(1, None)]) 624 625 @testing.only_on("sqlite") 626 def test_execute_compiled_favors_compiled_paramstyle(self): 627 users = self.tables.users 628 629 with patch.object(testing.db.dialect, "do_execute") as do_exec: 630 stmt = users.update().values(user_id=1, user_name="foo") 631 632 d1 = default.DefaultDialect(paramstyle="format") 633 d2 = default.DefaultDialect(paramstyle="pyformat") 634 635 with testing.db.begin() as conn: 636 conn.execute(stmt.compile(dialect=d1)) 637 conn.execute(stmt.compile(dialect=d2)) 638 639 eq_( 640 do_exec.mock_calls, 641 [ 642 call( 643 mock.ANY, 644 "UPDATE users SET user_id=%s, user_name=%s", 645 (1, "foo"), 646 mock.ANY, 647 ), 648 call( 649 mock.ANY, 650 "UPDATE users SET user_id=%(user_id)s, " 651 "user_name=%(user_name)s", 652 {"user_name": "foo", "user_id": 1}, 653 mock.ANY, 654 ), 655 ], 656 ) 657 658 @testing.requires.ad_hoc_engines 659 def test_engine_level_options(self): 660 eng = engines.testing_engine( 661 options={"execution_options": {"foo": "bar"}} 662 ) 663 with eng.connect() as conn: 664 eq_(conn._execution_options["foo"], "bar") 665 eq_( 666 conn.execution_options(bat="hoho")._execution_options["foo"], 667 "bar", 668 ) 669 eq_( 670 conn.execution_options(bat="hoho")._execution_options["bat"], 671 "hoho", 672 ) 673 eq_( 674 conn.execution_options(foo="hoho")._execution_options["foo"], 675 "hoho", 676 ) 677 eng.update_execution_options(foo="hoho") 678 conn = eng.connect() 679 eq_(conn._execution_options["foo"], "hoho") 680 681 @testing.requires.ad_hoc_engines 682 def test_generative_engine_execution_options(self): 683 eng = engines.testing_engine( 684 options={"execution_options": {"base": "x1"}} 685 ) 686 687 is_(eng.engine, eng) 688 689 eng1 = eng.execution_options(foo="b1") 690 is_(eng1.engine, eng1) 691 eng2 = eng.execution_options(foo="b2") 692 eng1a = eng1.execution_options(bar="a1") 693 eng2a = eng2.execution_options(foo="b3", bar="a2") 694 is_(eng2a.engine, eng2a) 695 696 eq_(eng._execution_options, {"base": "x1"}) 697 eq_(eng1._execution_options, {"base": "x1", "foo": "b1"}) 698 eq_(eng2._execution_options, {"base": "x1", "foo": "b2"}) 699 eq_(eng1a._execution_options, {"base": "x1", "foo": "b1", "bar": "a1"}) 700 eq_(eng2a._execution_options, {"base": "x1", "foo": "b3", "bar": "a2"}) 701 is_(eng1a.pool, eng.pool) 702 703 # test pool is shared 704 eng2.dispose() 705 is_(eng1a.pool, eng2.pool) 706 is_(eng.pool, eng2.pool) 707 708 @testing.requires.ad_hoc_engines 709 def test_autocommit_option_no_issue_first_connect(self): 710 eng = create_engine(testing.db.url) 711 eng.update_execution_options(autocommit=True) 712 conn = eng.connect() 713 eq_(conn._execution_options, {"autocommit": True}) 714 conn.close() 715 716 def test_initialize_rollback(self): 717 """test a rollback happens during first connect""" 718 eng = create_engine(testing.db.url) 719 with patch.object(eng.dialect, "do_rollback") as do_rollback: 720 assert do_rollback.call_count == 0 721 connection = eng.connect() 722 assert do_rollback.call_count == 1 723 connection.close() 724 725 @testing.requires.ad_hoc_engines 726 def test_dialect_init_uses_options(self): 727 eng = create_engine(testing.db.url) 728 729 def my_init(connection): 730 connection.execution_options(foo="bar").execute(select(1)) 731 732 with patch.object(eng.dialect, "initialize", my_init): 733 conn = eng.connect() 734 eq_(conn._execution_options, {}) 735 conn.close() 736 737 @testing.requires.ad_hoc_engines 738 def test_generative_engine_event_dispatch_hasevents(self): 739 def l1(*arg, **kw): 740 pass 741 742 eng = create_engine(testing.db.url) 743 assert not eng._has_events 744 event.listen(eng, "before_execute", l1) 745 eng2 = eng.execution_options(foo="bar") 746 assert eng2._has_events 747 748 def test_works_after_dispose(self): 749 eng = create_engine(testing.db.url) 750 for i in range(3): 751 with eng.connect() as conn: 752 eq_(conn.scalar(select(1)), 1) 753 eng.dispose() 754 755 def test_works_after_dispose_testing_engine(self): 756 eng = engines.testing_engine() 757 for i in range(3): 758 with eng.connect() as conn: 759 eq_(conn.scalar(select(1)), 1) 760 eng.dispose() 761 762 def test_scalar(self, connection): 763 conn = connection 764 users = self.tables.users 765 conn.execute( 766 users.insert(), 767 [ 768 {"user_id": 1, "user_name": "sandy"}, 769 {"user_id": 2, "user_name": "spongebob"}, 770 ], 771 ) 772 res = conn.scalar(select(users.c.user_name).order_by(users.c.user_id)) 773 eq_(res, "sandy") 774 775 def test_scalars(self, connection): 776 conn = connection 777 users = self.tables.users 778 conn.execute( 779 users.insert(), 780 [ 781 {"user_id": 1, "user_name": "sandy"}, 782 {"user_id": 2, "user_name": "spongebob"}, 783 ], 784 ) 785 res = conn.scalars(select(users.c.user_name).order_by(users.c.user_id)) 786 eq_(res.all(), ["sandy", "spongebob"]) 787 788 789class UnicodeReturnsTest(fixtures.TestBase): 790 @testing.requires.python3 791 def test_unicode_test_not_in_python3(self): 792 eng = engines.testing_engine() 793 eng.dialect.returns_unicode_strings = String.RETURNS_UNKNOWN 794 795 assert_raises_message( 796 tsa.exc.InvalidRequestError, 797 "RETURNS_UNKNOWN is unsupported in Python 3", 798 eng.connect, 799 ) 800 801 @testing.requires.python2 802 def test_unicode_test_fails_warning(self): 803 class MockCursor(engines.DBAPIProxyCursor): 804 def execute(self, stmt, params=None, **kw): 805 if "test unicode returns" in stmt: 806 raise self.engine.dialect.dbapi.DatabaseError("boom") 807 else: 808 return super(MockCursor, self).execute(stmt, params, **kw) 809 810 eng = engines.proxying_engine(cursor_cls=MockCursor) 811 with testing.expect_warnings( 812 "Exception attempting to detect unicode returns" 813 ): 814 eng.connect() 815 816 # because plain varchar passed, we don't know the correct answer 817 eq_(eng.dialect.returns_unicode_strings, String.RETURNS_CONDITIONAL) 818 eng.dispose() 819 820 821class ConvenienceExecuteTest(fixtures.TablesTest): 822 __backend__ = True 823 824 @classmethod 825 def define_tables(cls, metadata): 826 cls.table = Table( 827 "exec_test", 828 metadata, 829 Column("a", Integer), 830 Column("b", Integer), 831 test_needs_acid=True, 832 ) 833 834 def _trans_fn(self, is_transaction=False): 835 def go(conn, x, value=None): 836 if is_transaction: 837 conn = conn.connection 838 conn.execute(self.table.insert().values(a=x, b=value)) 839 840 return go 841 842 def _trans_rollback_fn(self, is_transaction=False): 843 def go(conn, x, value=None): 844 if is_transaction: 845 conn = conn.connection 846 conn.execute(self.table.insert().values(a=x, b=value)) 847 raise SomeException("breakage") 848 849 return go 850 851 def _assert_no_data(self): 852 with testing.db.connect() as conn: 853 eq_( 854 conn.scalar(select(func.count("*")).select_from(self.table)), 855 0, 856 ) 857 858 def _assert_fn(self, x, value=None): 859 with testing.db.connect() as conn: 860 eq_(conn.execute(self.table.select()).fetchall(), [(x, value)]) 861 862 def test_transaction_engine_ctx_commit(self): 863 fn = self._trans_fn() 864 ctx = testing.db.begin() 865 testing.run_as_contextmanager(ctx, fn, 5, value=8) 866 self._assert_fn(5, value=8) 867 868 def test_transaction_engine_ctx_begin_fails_dont_enter_enter(self): 869 """test #7272""" 870 engine = engines.testing_engine() 871 872 mock_connection = Mock( 873 return_value=Mock(begin=Mock(side_effect=Exception("boom"))) 874 ) 875 with mock.patch.object(engine, "_connection_cls", mock_connection): 876 if testing.requires.legacy_engine.enabled: 877 with expect_raises_message(Exception, "boom"): 878 engine.begin() 879 else: 880 # context manager isn't entered, doesn't actually call 881 # connect() or connection.begin() 882 engine.begin() 883 884 if testing.requires.legacy_engine.enabled: 885 eq_(mock_connection.return_value.close.mock_calls, [call()]) 886 else: 887 eq_(mock_connection.return_value.close.mock_calls, []) 888 889 def test_transaction_engine_ctx_begin_fails_include_enter(self): 890 """test #7272""" 891 engine = engines.testing_engine() 892 893 close_mock = Mock() 894 with mock.patch.object( 895 engine._connection_cls, 896 "begin", 897 Mock(side_effect=Exception("boom")), 898 ), mock.patch.object(engine._connection_cls, "close", close_mock): 899 with expect_raises_message(Exception, "boom"): 900 with engine.begin(): 901 pass 902 903 eq_(close_mock.mock_calls, [call()]) 904 905 def test_transaction_engine_ctx_rollback(self): 906 fn = self._trans_rollback_fn() 907 ctx = testing.db.begin() 908 assert_raises_message( 909 Exception, 910 "breakage", 911 testing.run_as_contextmanager, 912 ctx, 913 fn, 914 5, 915 value=8, 916 ) 917 self._assert_no_data() 918 919 def test_transaction_connection_ctx_commit(self): 920 fn = self._trans_fn(True) 921 with testing.db.connect() as conn: 922 ctx = conn.begin() 923 testing.run_as_contextmanager(ctx, fn, 5, value=8) 924 self._assert_fn(5, value=8) 925 926 def test_transaction_connection_ctx_rollback(self): 927 fn = self._trans_rollback_fn(True) 928 with testing.db.connect() as conn: 929 ctx = conn.begin() 930 assert_raises_message( 931 Exception, 932 "breakage", 933 testing.run_as_contextmanager, 934 ctx, 935 fn, 936 5, 937 value=8, 938 ) 939 self._assert_no_data() 940 941 def test_connection_as_ctx(self): 942 fn = self._trans_fn() 943 with testing.db.begin() as conn: 944 fn(conn, 5, value=8) 945 self._assert_fn(5, value=8) 946 947 @testing.fails_on("mysql+oursql", "oursql bug ? getting wrong rowcount") 948 @testing.requires.legacy_engine 949 def test_connect_as_ctx_noautocommit(self): 950 fn = self._trans_fn() 951 self._assert_no_data() 952 953 with testing.db.connect() as conn: 954 ctx = conn.execution_options(autocommit=False) 955 testing.run_as_contextmanager(ctx, fn, 5, value=8) 956 # autocommit is off 957 self._assert_no_data() 958 959 960class FutureConvenienceExecuteTest( 961 fixtures.FutureEngineMixin, ConvenienceExecuteTest 962): 963 __backend__ = True 964 965 966class CompiledCacheTest(fixtures.TestBase): 967 __backend__ = True 968 969 def test_cache(self, connection, metadata): 970 users = Table( 971 "users", 972 metadata, 973 Column( 974 "user_id", INT, primary_key=True, test_needs_autoincrement=True 975 ), 976 Column("user_name", VARCHAR(20)), 977 Column("extra_data", VARCHAR(20)), 978 ) 979 users.create(connection) 980 981 conn = connection 982 cache = {} 983 cached_conn = conn.execution_options(compiled_cache=cache) 984 985 ins = users.insert() 986 with patch.object( 987 ins, "_compiler", Mock(side_effect=ins._compiler) 988 ) as compile_mock: 989 cached_conn.execute(ins, {"user_name": "u1"}) 990 cached_conn.execute(ins, {"user_name": "u2"}) 991 cached_conn.execute(ins, {"user_name": "u3"}) 992 eq_(compile_mock.call_count, 1) 993 assert len(cache) == 1 994 eq_(conn.exec_driver_sql("select count(*) from users").scalar(), 3) 995 996 @testing.only_on( 997 ["sqlite", "mysql", "postgresql"], 998 "uses blob value that is problematic for some DBAPIs", 999 ) 1000 def test_cache_noleak_on_statement_values(self, metadata, connection): 1001 # This is a non regression test for an object reference leak caused 1002 # by the compiled_cache. 1003 1004 photo = Table( 1005 "photo", 1006 metadata, 1007 Column( 1008 "id", Integer, primary_key=True, test_needs_autoincrement=True 1009 ), 1010 Column("photo_blob", LargeBinary()), 1011 ) 1012 metadata.create_all(connection) 1013 1014 cache = {} 1015 cached_conn = connection.execution_options(compiled_cache=cache) 1016 1017 class PhotoBlob(bytearray): 1018 pass 1019 1020 blob = PhotoBlob(100) 1021 ref_blob = weakref.ref(blob) 1022 1023 ins = photo.insert() 1024 with patch.object( 1025 ins, "_compiler", Mock(side_effect=ins._compiler) 1026 ) as compile_mock: 1027 cached_conn.execute(ins, {"photo_blob": blob}) 1028 eq_(compile_mock.call_count, 1) 1029 eq_(len(cache), 1) 1030 eq_( 1031 connection.exec_driver_sql("select count(*) from photo").scalar(), 1032 1, 1033 ) 1034 1035 del blob 1036 1037 gc_collect() 1038 1039 # The compiled statement cache should not hold any reference to the 1040 # the statement values (only the keys). 1041 eq_(ref_blob(), None) 1042 1043 def test_keys_independent_of_ordering(self, connection, metadata): 1044 users = Table( 1045 "users", 1046 metadata, 1047 Column( 1048 "user_id", INT, primary_key=True, test_needs_autoincrement=True 1049 ), 1050 Column("user_name", VARCHAR(20)), 1051 Column("extra_data", VARCHAR(20)), 1052 ) 1053 users.create(connection) 1054 1055 connection.execute( 1056 users.insert(), 1057 {"user_id": 1, "user_name": "u1", "extra_data": "e1"}, 1058 ) 1059 cache = {} 1060 cached_conn = connection.execution_options(compiled_cache=cache) 1061 1062 upd = users.update().where(users.c.user_id == bindparam("b_user_id")) 1063 1064 with patch.object( 1065 upd, "_compiler", Mock(side_effect=upd._compiler) 1066 ) as compile_mock: 1067 cached_conn.execute( 1068 upd, 1069 util.OrderedDict( 1070 [ 1071 ("b_user_id", 1), 1072 ("user_name", "u2"), 1073 ("extra_data", "e2"), 1074 ] 1075 ), 1076 ) 1077 cached_conn.execute( 1078 upd, 1079 util.OrderedDict( 1080 [ 1081 ("b_user_id", 1), 1082 ("extra_data", "e3"), 1083 ("user_name", "u3"), 1084 ] 1085 ), 1086 ) 1087 cached_conn.execute( 1088 upd, 1089 util.OrderedDict( 1090 [ 1091 ("extra_data", "e4"), 1092 ("user_name", "u4"), 1093 ("b_user_id", 1), 1094 ] 1095 ), 1096 ) 1097 eq_(compile_mock.call_count, 1) 1098 eq_(len(cache), 1) 1099 1100 @testing.requires.schemas 1101 def test_schema_translate_in_key(self, metadata, connection): 1102 Table("x", metadata, Column("q", Integer)) 1103 Table("x", metadata, Column("q", Integer), schema=config.test_schema) 1104 metadata.create_all(connection) 1105 1106 m = MetaData() 1107 t1 = Table("x", m, Column("q", Integer)) 1108 ins = t1.insert() 1109 stmt = select(t1.c.q) 1110 1111 cache = {} 1112 1113 conn = connection.execution_options(compiled_cache=cache) 1114 conn.execute(ins, {"q": 1}) 1115 eq_(conn.scalar(stmt), 1) 1116 1117 conn = connection.execution_options( 1118 compiled_cache=cache, 1119 schema_translate_map={None: config.test_schema}, 1120 ) 1121 conn.execute(ins, {"q": 2}) 1122 eq_(conn.scalar(stmt), 2) 1123 1124 conn = connection.execution_options( 1125 compiled_cache=cache, 1126 schema_translate_map={None: None}, 1127 ) 1128 # should use default schema again even though statement 1129 # was compiled with test_schema in the map 1130 eq_(conn.scalar(stmt), 1) 1131 1132 conn = connection.execution_options( 1133 compiled_cache=cache, 1134 ) 1135 eq_(conn.scalar(stmt), 1) 1136 1137 1138class MockStrategyTest(fixtures.TestBase): 1139 def _engine_fixture(self): 1140 buf = util.StringIO() 1141 1142 def dump(sql, *multiparams, **params): 1143 buf.write(util.text_type(sql.compile(dialect=engine.dialect))) 1144 1145 engine = create_mock_engine("postgresql://", executor=dump) 1146 return engine, buf 1147 1148 def test_sequence_not_duped(self): 1149 engine, buf = self._engine_fixture() 1150 metadata = MetaData() 1151 t = Table( 1152 "testtable", 1153 metadata, 1154 Column( 1155 "pk", 1156 Integer, 1157 Sequence("testtable_pk_seq"), 1158 primary_key=True, 1159 ), 1160 ) 1161 1162 t.create(engine) 1163 t.drop(engine) 1164 1165 eq_(re.findall(r"CREATE (\w+)", buf.getvalue()), ["SEQUENCE", "TABLE"]) 1166 1167 eq_(re.findall(r"DROP (\w+)", buf.getvalue()), ["TABLE", "SEQUENCE"]) 1168 1169 1170class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): 1171 __requires__ = ("schemas",) 1172 __backend__ = True 1173 1174 @testing.fixture 1175 def plain_tables(self, metadata): 1176 t1 = Table( 1177 "t1", metadata, Column("x", Integer), schema=config.test_schema 1178 ) 1179 t2 = Table( 1180 "t2", metadata, Column("x", Integer), schema=config.test_schema 1181 ) 1182 t3 = Table("t3", metadata, Column("x", Integer), schema=None) 1183 1184 return t1, t2, t3 1185 1186 def test_create_table(self, plain_tables, connection): 1187 map_ = { 1188 None: config.test_schema, 1189 "foo": config.test_schema, 1190 "bar": None, 1191 } 1192 1193 metadata = MetaData() 1194 t1 = Table("t1", metadata, Column("x", Integer)) 1195 t2 = Table("t2", metadata, Column("x", Integer), schema="foo") 1196 t3 = Table("t3", metadata, Column("x", Integer), schema="bar") 1197 1198 with self.sql_execution_asserter(connection) as asserter: 1199 conn = connection.execution_options(schema_translate_map=map_) 1200 1201 t1.create(conn) 1202 t2.create(conn) 1203 t3.create(conn) 1204 1205 t3.drop(conn) 1206 t2.drop(conn) 1207 t1.drop(conn) 1208 1209 asserter.assert_( 1210 CompiledSQL("CREATE TABLE __[SCHEMA__none].t1 (x INTEGER)"), 1211 CompiledSQL("CREATE TABLE __[SCHEMA_foo].t2 (x INTEGER)"), 1212 CompiledSQL("CREATE TABLE __[SCHEMA_bar].t3 (x INTEGER)"), 1213 CompiledSQL("DROP TABLE __[SCHEMA_bar].t3"), 1214 CompiledSQL("DROP TABLE __[SCHEMA_foo].t2"), 1215 CompiledSQL("DROP TABLE __[SCHEMA__none].t1"), 1216 ) 1217 1218 def test_ddl_hastable(self, plain_tables, connection): 1219 1220 map_ = { 1221 None: config.test_schema, 1222 "foo": config.test_schema, 1223 "bar": None, 1224 } 1225 1226 metadata = MetaData() 1227 Table("t1", metadata, Column("x", Integer)) 1228 Table("t2", metadata, Column("x", Integer), schema="foo") 1229 Table("t3", metadata, Column("x", Integer), schema="bar") 1230 1231 conn = connection.execution_options(schema_translate_map=map_) 1232 metadata.create_all(conn) 1233 1234 insp = inspect(connection) 1235 is_true(insp.has_table("t1", schema=config.test_schema)) 1236 is_true(insp.has_table("t2", schema=config.test_schema)) 1237 is_true(insp.has_table("t3", schema=None)) 1238 1239 conn = connection.execution_options(schema_translate_map=map_) 1240 1241 # if this test fails, the tables won't get dropped. so need a 1242 # more robust fixture for this 1243 metadata.drop_all(conn) 1244 1245 insp = inspect(connection) 1246 is_false(insp.has_table("t1", schema=config.test_schema)) 1247 is_false(insp.has_table("t2", schema=config.test_schema)) 1248 is_false(insp.has_table("t3", schema=None)) 1249 1250 def test_option_on_execute(self, plain_tables, connection): 1251 # provided by metadata fixture provided by plain_tables fixture 1252 self.metadata.create_all(connection) 1253 1254 map_ = { 1255 None: config.test_schema, 1256 "foo": config.test_schema, 1257 "bar": None, 1258 } 1259 1260 metadata = MetaData() 1261 t1 = Table("t1", metadata, Column("x", Integer)) 1262 t2 = Table("t2", metadata, Column("x", Integer), schema="foo") 1263 t3 = Table("t3", metadata, Column("x", Integer), schema="bar") 1264 1265 with self.sql_execution_asserter(connection) as asserter: 1266 conn = connection 1267 execution_options = {"schema_translate_map": map_} 1268 conn._execute_20( 1269 t1.insert(), {"x": 1}, execution_options=execution_options 1270 ) 1271 conn._execute_20( 1272 t2.insert(), {"x": 1}, execution_options=execution_options 1273 ) 1274 conn._execute_20( 1275 t3.insert(), {"x": 1}, execution_options=execution_options 1276 ) 1277 1278 conn._execute_20( 1279 t1.update().values(x=1).where(t1.c.x == 1), 1280 execution_options=execution_options, 1281 ) 1282 conn._execute_20( 1283 t2.update().values(x=2).where(t2.c.x == 1), 1284 execution_options=execution_options, 1285 ) 1286 conn._execute_20( 1287 t3.update().values(x=3).where(t3.c.x == 1), 1288 execution_options=execution_options, 1289 ) 1290 1291 eq_( 1292 conn._execute_20( 1293 select(t1.c.x), execution_options=execution_options 1294 ).scalar(), 1295 1, 1296 ) 1297 eq_( 1298 conn._execute_20( 1299 select(t2.c.x), execution_options=execution_options 1300 ).scalar(), 1301 2, 1302 ) 1303 eq_( 1304 conn._execute_20( 1305 select(t3.c.x), execution_options=execution_options 1306 ).scalar(), 1307 3, 1308 ) 1309 1310 conn._execute_20(t1.delete(), execution_options=execution_options) 1311 conn._execute_20(t2.delete(), execution_options=execution_options) 1312 conn._execute_20(t3.delete(), execution_options=execution_options) 1313 1314 asserter.assert_( 1315 CompiledSQL("INSERT INTO __[SCHEMA__none].t1 (x) VALUES (:x)"), 1316 CompiledSQL("INSERT INTO __[SCHEMA_foo].t2 (x) VALUES (:x)"), 1317 CompiledSQL("INSERT INTO __[SCHEMA_bar].t3 (x) VALUES (:x)"), 1318 CompiledSQL( 1319 "UPDATE __[SCHEMA__none].t1 SET x=:x WHERE " 1320 "__[SCHEMA__none].t1.x = :x_1" 1321 ), 1322 CompiledSQL( 1323 "UPDATE __[SCHEMA_foo].t2 SET x=:x WHERE " 1324 "__[SCHEMA_foo].t2.x = :x_1" 1325 ), 1326 CompiledSQL( 1327 "UPDATE __[SCHEMA_bar].t3 SET x=:x WHERE " 1328 "__[SCHEMA_bar].t3.x = :x_1" 1329 ), 1330 CompiledSQL( 1331 "SELECT __[SCHEMA__none].t1.x FROM __[SCHEMA__none].t1" 1332 ), 1333 CompiledSQL("SELECT __[SCHEMA_foo].t2.x FROM __[SCHEMA_foo].t2"), 1334 CompiledSQL("SELECT __[SCHEMA_bar].t3.x FROM __[SCHEMA_bar].t3"), 1335 CompiledSQL("DELETE FROM __[SCHEMA__none].t1"), 1336 CompiledSQL("DELETE FROM __[SCHEMA_foo].t2"), 1337 CompiledSQL("DELETE FROM __[SCHEMA_bar].t3"), 1338 ) 1339 1340 def test_crud(self, plain_tables, connection): 1341 # provided by metadata fixture provided by plain_tables fixture 1342 self.metadata.create_all(connection) 1343 1344 map_ = { 1345 None: config.test_schema, 1346 "foo": config.test_schema, 1347 "bar": None, 1348 } 1349 1350 metadata = MetaData() 1351 t1 = Table("t1", metadata, Column("x", Integer)) 1352 t2 = Table("t2", metadata, Column("x", Integer), schema="foo") 1353 t3 = Table("t3", metadata, Column("x", Integer), schema="bar") 1354 1355 with self.sql_execution_asserter(connection) as asserter: 1356 conn = connection.execution_options(schema_translate_map=map_) 1357 1358 conn.execute(t1.insert(), {"x": 1}) 1359 conn.execute(t2.insert(), {"x": 1}) 1360 conn.execute(t3.insert(), {"x": 1}) 1361 1362 conn.execute(t1.update().values(x=1).where(t1.c.x == 1)) 1363 conn.execute(t2.update().values(x=2).where(t2.c.x == 1)) 1364 conn.execute(t3.update().values(x=3).where(t3.c.x == 1)) 1365 1366 eq_(conn.scalar(select(t1.c.x)), 1) 1367 eq_(conn.scalar(select(t2.c.x)), 2) 1368 eq_(conn.scalar(select(t3.c.x)), 3) 1369 1370 conn.execute(t1.delete()) 1371 conn.execute(t2.delete()) 1372 conn.execute(t3.delete()) 1373 1374 asserter.assert_( 1375 CompiledSQL("INSERT INTO __[SCHEMA__none].t1 (x) VALUES (:x)"), 1376 CompiledSQL("INSERT INTO __[SCHEMA_foo].t2 (x) VALUES (:x)"), 1377 CompiledSQL("INSERT INTO __[SCHEMA_bar].t3 (x) VALUES (:x)"), 1378 CompiledSQL( 1379 "UPDATE __[SCHEMA__none].t1 SET x=:x WHERE " 1380 "__[SCHEMA__none].t1.x = :x_1" 1381 ), 1382 CompiledSQL( 1383 "UPDATE __[SCHEMA_foo].t2 SET x=:x WHERE " 1384 "__[SCHEMA_foo].t2.x = :x_1" 1385 ), 1386 CompiledSQL( 1387 "UPDATE __[SCHEMA_bar].t3 SET x=:x WHERE " 1388 "__[SCHEMA_bar].t3.x = :x_1" 1389 ), 1390 CompiledSQL( 1391 "SELECT __[SCHEMA__none].t1.x FROM __[SCHEMA__none].t1" 1392 ), 1393 CompiledSQL("SELECT __[SCHEMA_foo].t2.x FROM __[SCHEMA_foo].t2"), 1394 CompiledSQL("SELECT __[SCHEMA_bar].t3.x FROM __[SCHEMA_bar].t3"), 1395 CompiledSQL("DELETE FROM __[SCHEMA__none].t1"), 1396 CompiledSQL("DELETE FROM __[SCHEMA_foo].t2"), 1397 CompiledSQL("DELETE FROM __[SCHEMA_bar].t3"), 1398 ) 1399 1400 def test_via_engine(self, plain_tables, metadata): 1401 1402 with config.db.begin() as connection: 1403 metadata.create_all(connection) 1404 1405 map_ = { 1406 None: config.test_schema, 1407 "foo": config.test_schema, 1408 "bar": None, 1409 } 1410 1411 metadata = MetaData() 1412 t2 = Table("t2", metadata, Column("x", Integer), schema="foo") 1413 1414 with self.sql_execution_asserter(config.db) as asserter: 1415 eng = config.db.execution_options(schema_translate_map=map_) 1416 with eng.connect() as conn: 1417 conn.execute(select(t2.c.x)) 1418 asserter.assert_( 1419 CompiledSQL("SELECT __[SCHEMA_foo].t2.x FROM __[SCHEMA_foo].t2") 1420 ) 1421 1422 1423class ExecutionOptionsTest(fixtures.TestBase): 1424 def test_dialect_conn_options(self, testing_engine): 1425 engine = testing_engine("sqlite://", options=dict(_initialize=False)) 1426 engine.dialect = Mock() 1427 with engine.connect() as conn: 1428 c2 = conn.execution_options(foo="bar") 1429 eq_( 1430 engine.dialect.set_connection_execution_options.mock_calls, 1431 [call(c2, {"foo": "bar"})], 1432 ) 1433 1434 def test_dialect_engine_options(self, testing_engine): 1435 engine = testing_engine("sqlite://") 1436 engine.dialect = Mock() 1437 e2 = engine.execution_options(foo="bar") 1438 eq_( 1439 engine.dialect.set_engine_execution_options.mock_calls, 1440 [call(e2, {"foo": "bar"})], 1441 ) 1442 1443 def test_dialect_engine_construction_options(self): 1444 dialect = Mock() 1445 engine = Engine( 1446 Mock(), dialect, Mock(), execution_options={"foo": "bar"} 1447 ) 1448 eq_( 1449 dialect.set_engine_execution_options.mock_calls, 1450 [call(engine, {"foo": "bar"})], 1451 ) 1452 1453 def test_propagate_engine_to_connection(self, testing_engine): 1454 engine = testing_engine( 1455 "sqlite://", options=dict(execution_options={"foo": "bar"}) 1456 ) 1457 with engine.connect() as conn: 1458 eq_(conn._execution_options, {"foo": "bar"}) 1459 1460 def test_propagate_option_engine_to_connection(self, testing_engine): 1461 e1 = testing_engine( 1462 "sqlite://", options=dict(execution_options={"foo": "bar"}) 1463 ) 1464 e2 = e1.execution_options(bat="hoho") 1465 c1 = e1.connect() 1466 c2 = e2.connect() 1467 eq_(c1._execution_options, {"foo": "bar"}) 1468 eq_(c2._execution_options, {"foo": "bar", "bat": "hoho"}) 1469 1470 c1.close() 1471 c2.close() 1472 1473 def test_get_engine_execution_options(self, testing_engine): 1474 engine = testing_engine("sqlite://") 1475 engine.dialect = Mock() 1476 e2 = engine.execution_options(foo="bar") 1477 1478 eq_(e2.get_execution_options(), {"foo": "bar"}) 1479 1480 def test_get_connection_execution_options(self, testing_engine): 1481 engine = testing_engine("sqlite://", options=dict(_initialize=False)) 1482 engine.dialect = Mock() 1483 with engine.connect() as conn: 1484 c = conn.execution_options(foo="bar") 1485 1486 eq_(c.get_execution_options(), {"foo": "bar"}) 1487 1488 1489class EngineEventsTest(fixtures.TestBase): 1490 __requires__ = ("ad_hoc_engines",) 1491 __backend__ = True 1492 1493 def teardown_test(self): 1494 Engine.dispatch._clear() 1495 Engine._has_events = False 1496 1497 def _assert_stmts(self, expected, received): 1498 list(received) 1499 1500 for stmt, params, posn in expected: 1501 if not received: 1502 assert False, "Nothing available for stmt: %s" % stmt 1503 while received: 1504 teststmt, testparams, testmultiparams = received.pop(0) 1505 teststmt = ( 1506 re.compile(r"[\n\t ]+", re.M).sub(" ", teststmt).strip() 1507 ) 1508 if teststmt.startswith(stmt) and ( 1509 testparams == params or testparams == posn 1510 ): 1511 break 1512 1513 def test_per_engine_independence(self, testing_engine): 1514 e1 = testing_engine(config.db_url) 1515 e2 = testing_engine(config.db_url) 1516 1517 canary = Mock() 1518 event.listen(e1, "before_execute", canary) 1519 s1 = select(1) 1520 s2 = select(2) 1521 1522 with e1.connect() as conn: 1523 conn.execute(s1) 1524 1525 with e2.connect() as conn: 1526 conn.execute(s2) 1527 eq_([arg[1][1] for arg in canary.mock_calls], [s1]) 1528 event.listen(e2, "before_execute", canary) 1529 1530 with e1.connect() as conn: 1531 conn.execute(s1) 1532 1533 with e2.connect() as conn: 1534 conn.execute(s2) 1535 eq_([arg[1][1] for arg in canary.mock_calls], [s1, s1, s2]) 1536 1537 def test_per_engine_plus_global(self, testing_engine): 1538 canary = Mock() 1539 event.listen(Engine, "before_execute", canary.be1) 1540 e1 = testing_engine(config.db_url) 1541 e2 = testing_engine(config.db_url) 1542 1543 event.listen(e1, "before_execute", canary.be2) 1544 1545 event.listen(Engine, "before_execute", canary.be3) 1546 1547 with e1.connect() as conn: 1548 conn.execute(select(1)) 1549 eq_(canary.be1.call_count, 1) 1550 eq_(canary.be2.call_count, 1) 1551 1552 with e2.connect() as conn: 1553 conn.execute(select(1)) 1554 1555 eq_(canary.be1.call_count, 2) 1556 eq_(canary.be2.call_count, 1) 1557 eq_(canary.be3.call_count, 2) 1558 1559 def test_emit_sql_in_autobegin(self, testing_engine): 1560 e1 = testing_engine(config.db_url) 1561 1562 canary = Mock() 1563 1564 @event.listens_for(e1, "begin") 1565 def begin(connection): 1566 result = connection.execute(select(1)).scalar() 1567 canary.got_result(result) 1568 1569 with e1.connect() as conn: 1570 assert not conn._is_future 1571 1572 with conn.begin(): 1573 conn.execute(select(1)).scalar() 1574 assert conn.in_transaction() 1575 1576 assert not conn.in_transaction() 1577 1578 eq_(canary.mock_calls, [call.got_result(1)]) 1579 1580 def test_per_connection_plus_engine(self, testing_engine): 1581 canary = Mock() 1582 e1 = testing_engine(config.db_url) 1583 1584 event.listen(e1, "before_execute", canary.be1) 1585 1586 conn = e1.connect() 1587 event.listen(conn, "before_execute", canary.be2) 1588 conn.execute(select(1)) 1589 1590 eq_(canary.be1.call_count, 1) 1591 eq_(canary.be2.call_count, 1) 1592 1593 if testing.requires.legacy_engine.enabled: 1594 conn._branch().execute(select(1)) 1595 eq_(canary.be1.call_count, 2) 1596 eq_(canary.be2.call_count, 2) 1597 1598 @testing.combinations( 1599 (True, False), 1600 (True, True), 1601 (False, False), 1602 argnames="mock_out_on_connect, add_our_own_onconnect", 1603 ) 1604 def test_insert_connect_is_definitely_first( 1605 self, mock_out_on_connect, add_our_own_onconnect, testing_engine 1606 ): 1607 """test issue #5708. 1608 1609 We want to ensure that a single "connect" event may be invoked 1610 *before* dialect initialize as well as before dialect on_connects. 1611 1612 This is also partially reliant on the changes we made as a result of 1613 #5497, however here we go further with the changes and remove use 1614 of the pool first_connect() event entirely so that the startup 1615 for a dialect is fully consistent. 1616 1617 """ 1618 if mock_out_on_connect: 1619 if add_our_own_onconnect: 1620 1621 def our_connect(connection): 1622 m1.our_connect("our connect event") 1623 1624 patcher = mock.patch.object( 1625 config.db.dialect.__class__, 1626 "on_connect", 1627 lambda self: our_connect, 1628 ) 1629 else: 1630 patcher = mock.patch.object( 1631 config.db.dialect.__class__, 1632 "on_connect", 1633 lambda self: None, 1634 ) 1635 else: 1636 patcher = util.nullcontext() 1637 1638 with patcher: 1639 e1 = testing_engine(config.db_url) 1640 1641 initialize = e1.dialect.initialize 1642 1643 def init(connection): 1644 initialize(connection) 1645 1646 with mock.patch.object( 1647 e1.dialect, "initialize", side_effect=init 1648 ) as m1: 1649 1650 @event.listens_for(e1, "connect", insert=True) 1651 def go1(dbapi_conn, xyz): 1652 m1.foo("custom event first") 1653 1654 @event.listens_for(e1, "connect") 1655 def go2(dbapi_conn, xyz): 1656 m1.foo("custom event last") 1657 1658 c1 = e1.connect() 1659 1660 m1.bar("ok next connection") 1661 1662 c2 = e1.connect() 1663 1664 # this happens with sqlite singletonthreadpool. 1665 # we can almost use testing.requires.independent_connections 1666 # but sqlite file backend will also have independent 1667 # connections here. 1668 its_the_same_connection = ( 1669 c1.connection.dbapi_connection 1670 is c2.connection.dbapi_connection 1671 ) 1672 c1.close() 1673 c2.close() 1674 1675 if add_our_own_onconnect: 1676 calls = [ 1677 mock.call.foo("custom event first"), 1678 mock.call.our_connect("our connect event"), 1679 mock.call(mock.ANY), 1680 mock.call.foo("custom event last"), 1681 mock.call.bar("ok next connection"), 1682 ] 1683 else: 1684 calls = [ 1685 mock.call.foo("custom event first"), 1686 mock.call(mock.ANY), 1687 mock.call.foo("custom event last"), 1688 mock.call.bar("ok next connection"), 1689 ] 1690 1691 if not its_the_same_connection: 1692 if add_our_own_onconnect: 1693 calls.extend( 1694 [ 1695 mock.call.foo("custom event first"), 1696 mock.call.our_connect("our connect event"), 1697 mock.call.foo("custom event last"), 1698 ] 1699 ) 1700 else: 1701 calls.extend( 1702 [ 1703 mock.call.foo("custom event first"), 1704 mock.call.foo("custom event last"), 1705 ] 1706 ) 1707 eq_(m1.mock_calls, calls) 1708 1709 def test_new_exec_driver_sql_no_events(self): 1710 m1 = Mock() 1711 1712 def select1(db): 1713 return str(select(1).compile(dialect=db.dialect)) 1714 1715 with testing.db.connect() as conn: 1716 event.listen(conn, "before_execute", m1.before_execute) 1717 event.listen(conn, "after_execute", m1.after_execute) 1718 conn.exec_driver_sql(select1(testing.db)) 1719 eq_(m1.mock_calls, []) 1720 1721 def test_add_event_after_connect(self, testing_engine): 1722 # new feature as of #2978 1723 1724 canary = Mock() 1725 e1 = testing_engine(config.db_url, future=False) 1726 assert not e1._has_events 1727 1728 conn = e1.connect() 1729 1730 event.listen(e1, "before_execute", canary.be1) 1731 conn.execute(select(1)) 1732 1733 eq_(canary.be1.call_count, 1) 1734 1735 conn._branch().execute(select(1)) 1736 eq_(canary.be1.call_count, 2) 1737 1738 def test_force_conn_events_false(self, testing_engine): 1739 canary = Mock() 1740 e1 = testing_engine(config.db_url, future=False) 1741 assert not e1._has_events 1742 1743 event.listen(e1, "before_execute", canary.be1) 1744 1745 conn = e1._connection_cls( 1746 e1, connection=e1.raw_connection(), _has_events=False 1747 ) 1748 1749 conn.execute(select(1)) 1750 1751 eq_(canary.be1.call_count, 0) 1752 1753 conn._branch().execute(select(1)) 1754 eq_(canary.be1.call_count, 0) 1755 1756 def test_cursor_events_ctx_execute_scalar(self, testing_engine): 1757 canary = Mock() 1758 e1 = testing_engine(config.db_url) 1759 1760 event.listen(e1, "before_cursor_execute", canary.bce) 1761 event.listen(e1, "after_cursor_execute", canary.ace) 1762 1763 stmt = str(select(1).compile(dialect=e1.dialect)) 1764 1765 with e1.connect() as conn: 1766 dialect = conn.dialect 1767 1768 ctx = dialect.execution_ctx_cls._init_statement( 1769 dialect, conn, conn.connection, {}, stmt, {} 1770 ) 1771 1772 ctx._execute_scalar(stmt, Integer()) 1773 1774 eq_( 1775 canary.bce.mock_calls, 1776 [call(conn, ctx.cursor, stmt, ctx.parameters[0], ctx, False)], 1777 ) 1778 eq_( 1779 canary.ace.mock_calls, 1780 [call(conn, ctx.cursor, stmt, ctx.parameters[0], ctx, False)], 1781 ) 1782 1783 def test_cursor_events_execute(self, testing_engine): 1784 canary = Mock() 1785 e1 = testing_engine(config.db_url) 1786 1787 event.listen(e1, "before_cursor_execute", canary.bce) 1788 event.listen(e1, "after_cursor_execute", canary.ace) 1789 1790 stmt = str(select(1).compile(dialect=e1.dialect)) 1791 1792 with e1.connect() as conn: 1793 1794 result = conn.exec_driver_sql(stmt) 1795 eq_(result.scalar(), 1) 1796 1797 ctx = result.context 1798 eq_( 1799 canary.bce.mock_calls, 1800 [call(conn, ctx.cursor, stmt, ctx.parameters[0], ctx, False)], 1801 ) 1802 eq_( 1803 canary.ace.mock_calls, 1804 [call(conn, ctx.cursor, stmt, ctx.parameters[0], ctx, False)], 1805 ) 1806 1807 @testing.combinations( 1808 ( 1809 ([{"x": 5, "y": 10}, {"x": 8, "y": 9}],), 1810 {}, 1811 [{"x": 5, "y": 10}, {"x": 8, "y": 9}], 1812 {}, 1813 ), 1814 (({"z": 10},), {}, [], {"z": 10}), 1815 argnames="multiparams, params, expected_multiparams, expected_params", 1816 ) 1817 def test_modify_parameters_from_event_one( 1818 self, 1819 multiparams, 1820 params, 1821 expected_multiparams, 1822 expected_params, 1823 testing_engine, 1824 ): 1825 # this is testing both the normalization added to parameters 1826 # as of I97cb4d06adfcc6b889f10d01cc7775925cffb116 as well as 1827 # that the return value from the event is taken as the new set 1828 # of parameters. 1829 def before_execute( 1830 conn, clauseelement, multiparams, params, execution_options 1831 ): 1832 eq_(multiparams, expected_multiparams) 1833 eq_(params, expected_params) 1834 return clauseelement, (), {"q": "15"} 1835 1836 def after_execute( 1837 conn, clauseelement, multiparams, params, result, execution_options 1838 ): 1839 eq_(multiparams, ()) 1840 eq_(params, {"q": "15"}) 1841 1842 e1 = testing_engine(config.db_url) 1843 event.listen(e1, "before_execute", before_execute, retval=True) 1844 event.listen(e1, "after_execute", after_execute) 1845 1846 with e1.connect() as conn: 1847 result = conn.execute( 1848 select(bindparam("q", type_=String)), *multiparams, **params 1849 ) 1850 eq_(result.all(), [("15",)]) 1851 1852 @testing.provide_metadata 1853 def test_modify_parameters_from_event_two(self, connection): 1854 t = Table("t", self.metadata, Column("q", Integer)) 1855 1856 t.create(connection) 1857 1858 def before_execute( 1859 conn, clauseelement, multiparams, params, execution_options 1860 ): 1861 return clauseelement, [{"q": 15}, {"q": 19}], {} 1862 1863 event.listen(connection, "before_execute", before_execute, retval=True) 1864 connection.execute(t.insert(), {"q": 12}) 1865 event.remove(connection, "before_execute", before_execute) 1866 1867 eq_( 1868 connection.execute(select(t).order_by(t.c.q)).fetchall(), 1869 [(15,), (19,)], 1870 ) 1871 1872 def test_modify_parameters_from_event_three( 1873 self, connection, testing_engine 1874 ): 1875 def before_execute( 1876 conn, clauseelement, multiparams, params, execution_options 1877 ): 1878 return clauseelement, [{"q": 15}, {"q": 19}], {"q": 7} 1879 1880 e1 = testing_engine(config.db_url) 1881 event.listen(e1, "before_execute", before_execute, retval=True) 1882 1883 with expect_raises_message( 1884 tsa.exc.InvalidRequestError, 1885 "Event handler can't return non-empty multiparams " 1886 "and params at the same time", 1887 ): 1888 with e1.connect() as conn: 1889 conn.execute(select(literal("1"))) 1890 1891 @testing.only_on("sqlite") 1892 def test_dont_modify_statement_driversql(self, connection): 1893 m1 = mock.Mock() 1894 1895 @event.listens_for(connection, "before_execute", retval=True) 1896 def _modify( 1897 conn, clauseelement, multiparams, params, execution_options 1898 ): 1899 m1.run_event() 1900 return clauseelement.replace("hi", "there"), multiparams, params 1901 1902 # the event does not take effect for the "driver SQL" option 1903 eq_(connection.exec_driver_sql("select 'hi'").scalar(), "hi") 1904 1905 # event is not called at all 1906 eq_(m1.mock_calls, []) 1907 1908 @testing.combinations((True,), (False,), argnames="future") 1909 @testing.only_on("sqlite") 1910 def test_modify_statement_internal_driversql(self, connection, future): 1911 m1 = mock.Mock() 1912 1913 @event.listens_for(connection, "before_execute", retval=True) 1914 def _modify( 1915 conn, clauseelement, multiparams, params, execution_options 1916 ): 1917 m1.run_event() 1918 return clauseelement.replace("hi", "there"), multiparams, params 1919 1920 eq_( 1921 connection._exec_driver_sql( 1922 "select 'hi'", [], {}, {}, future=future 1923 ).scalar(), 1924 "hi" if future else "there", 1925 ) 1926 1927 if future: 1928 eq_(m1.mock_calls, []) 1929 else: 1930 eq_(m1.mock_calls, [call.run_event()]) 1931 1932 def test_modify_statement_clauseelement(self, connection): 1933 @event.listens_for(connection, "before_execute", retval=True) 1934 def _modify( 1935 conn, clauseelement, multiparams, params, execution_options 1936 ): 1937 return select(literal_column("'there'")), multiparams, params 1938 1939 eq_(connection.scalar(select(literal_column("'hi'"))), "there") 1940 1941 def test_argument_format_execute(self, testing_engine): 1942 def before_execute( 1943 conn, clauseelement, multiparams, params, execution_options 1944 ): 1945 assert isinstance(multiparams, (list, tuple)) 1946 assert isinstance(params, collections_abc.Mapping) 1947 1948 def after_execute( 1949 conn, clauseelement, multiparams, params, result, execution_options 1950 ): 1951 assert isinstance(multiparams, (list, tuple)) 1952 assert isinstance(params, collections_abc.Mapping) 1953 1954 e1 = testing_engine(config.db_url) 1955 event.listen(e1, "before_execute", before_execute) 1956 event.listen(e1, "after_execute", after_execute) 1957 1958 with e1.connect() as conn: 1959 conn.execute(select(1)) 1960 conn.execute(select(1).compile(dialect=e1.dialect).statement) 1961 conn.execute(select(1).compile(dialect=e1.dialect)) 1962 1963 conn._execute_compiled( 1964 select(1).compile(dialect=e1.dialect), (), {}, {} 1965 ) 1966 1967 def test_execute_events(self): 1968 1969 stmts = [] 1970 cursor_stmts = [] 1971 1972 def execute( 1973 conn, clauseelement, multiparams, params, execution_options 1974 ): 1975 stmts.append((str(clauseelement), params, multiparams)) 1976 1977 def cursor_execute( 1978 conn, cursor, statement, parameters, context, executemany 1979 ): 1980 cursor_stmts.append((str(statement), parameters, None)) 1981 1982 # TODO: this test is kind of a mess 1983 1984 for engine in [ 1985 engines.testing_engine(options=dict(implicit_returning=False)), 1986 engines.testing_engine( 1987 options=dict(implicit_returning=False) 1988 ).connect(), 1989 ]: 1990 event.listen(engine, "before_execute", execute) 1991 event.listen(engine, "before_cursor_execute", cursor_execute) 1992 m = MetaData() 1993 t1 = Table( 1994 "t1", 1995 m, 1996 Column("c1", Integer, primary_key=True), 1997 Column( 1998 "c2", 1999 String(50), 2000 default=func.lower("Foo"), 2001 primary_key=True, 2002 ), 2003 ) 2004 2005 if isinstance(engine, Connection): 2006 ctx = None 2007 conn = engine 2008 else: 2009 ctx = conn = engine.connect() 2010 2011 trans = conn.begin() 2012 try: 2013 m.create_all(conn, checkfirst=False) 2014 try: 2015 conn.execute(t1.insert(), dict(c1=5, c2="some data")) 2016 conn.execute(t1.insert(), dict(c1=6)) 2017 eq_( 2018 conn.execute(text("select * from t1")).fetchall(), 2019 [(5, "some data"), (6, "foo")], 2020 ) 2021 finally: 2022 m.drop_all(conn) 2023 trans.commit() 2024 finally: 2025 if ctx: 2026 ctx.close() 2027 2028 compiled = [ 2029 ("CREATE TABLE t1", {}, None), 2030 ( 2031 "INSERT INTO t1 (c1, c2)", 2032 {"c2": "some data", "c1": 5}, 2033 (), 2034 ), 2035 ("INSERT INTO t1 (c1, c2)", {"c1": 6}, ()), 2036 ("select * from t1", {}, None), 2037 ("DROP TABLE t1", {}, None), 2038 ] 2039 2040 cursor = [ 2041 ("CREATE TABLE t1", {}, ()), 2042 ( 2043 "INSERT INTO t1 (c1, c2)", 2044 {"c2": "some data", "c1": 5}, 2045 (5, "some data"), 2046 ), 2047 ("SELECT lower", {"lower_2": "Foo"}, ("Foo",)), 2048 ( 2049 "INSERT INTO t1 (c1, c2)", 2050 {"c2": "foo", "c1": 6}, 2051 (6, "foo"), 2052 ), 2053 ("select * from t1", {}, ()), 2054 ("DROP TABLE t1", {}, ()), 2055 ] 2056 self._assert_stmts(compiled, stmts) 2057 self._assert_stmts(cursor, cursor_stmts) 2058 2059 def test_options(self): 2060 canary = [] 2061 2062 def execute(conn, *args, **kw): 2063 canary.append("execute") 2064 2065 def cursor_execute(conn, *args, **kw): 2066 canary.append("cursor_execute") 2067 2068 engine = engines.testing_engine() 2069 event.listen(engine, "before_execute", execute) 2070 event.listen(engine, "before_cursor_execute", cursor_execute) 2071 conn = engine.connect() 2072 c2 = conn.execution_options(foo="bar") 2073 eq_(c2._execution_options, {"foo": "bar"}) 2074 c2.execute(select(1)) 2075 c3 = c2.execution_options(bar="bat") 2076 eq_(c3._execution_options, {"foo": "bar", "bar": "bat"}) 2077 eq_(canary, ["execute", "cursor_execute"]) 2078 2079 @testing.requires.ad_hoc_engines 2080 def test_generative_engine_event_dispatch(self): 2081 canary = [] 2082 2083 def l1(*arg, **kw): 2084 canary.append("l1") 2085 2086 def l2(*arg, **kw): 2087 canary.append("l2") 2088 2089 def l3(*arg, **kw): 2090 canary.append("l3") 2091 2092 eng = engines.testing_engine( 2093 options={"execution_options": {"base": "x1"}} 2094 ) 2095 event.listen(eng, "before_execute", l1) 2096 2097 eng1 = eng.execution_options(foo="b1") 2098 event.listen(eng, "before_execute", l2) 2099 event.listen(eng1, "before_execute", l3) 2100 2101 with eng.connect() as conn: 2102 conn.execute(select(1)) 2103 2104 eq_(canary, ["l1", "l2"]) 2105 2106 with eng1.connect() as conn: 2107 conn.execute(select(1)) 2108 2109 eq_(canary, ["l1", "l2", "l3", "l1", "l2"]) 2110 2111 @testing.requires.ad_hoc_engines 2112 def test_clslevel_engine_event_options(self): 2113 canary = [] 2114 2115 def l1(*arg, **kw): 2116 canary.append("l1") 2117 2118 def l2(*arg, **kw): 2119 canary.append("l2") 2120 2121 def l3(*arg, **kw): 2122 canary.append("l3") 2123 2124 def l4(*arg, **kw): 2125 canary.append("l4") 2126 2127 event.listen(Engine, "before_execute", l1) 2128 2129 eng = engines.testing_engine( 2130 options={"execution_options": {"base": "x1"}} 2131 ) 2132 event.listen(eng, "before_execute", l2) 2133 2134 eng1 = eng.execution_options(foo="b1") 2135 event.listen(eng, "before_execute", l3) 2136 event.listen(eng1, "before_execute", l4) 2137 2138 with eng.connect() as conn: 2139 conn.execute(select(1)) 2140 2141 eq_(canary, ["l1", "l2", "l3"]) 2142 2143 with eng1.connect() as conn: 2144 conn.execute(select(1)) 2145 2146 eq_(canary, ["l1", "l2", "l3", "l4", "l1", "l2", "l3"]) 2147 2148 canary[:] = [] 2149 2150 event.remove(Engine, "before_execute", l1) 2151 event.remove(eng1, "before_execute", l4) 2152 event.remove(eng, "before_execute", l3) 2153 2154 with eng1.connect() as conn: 2155 conn.execute(select(1)) 2156 eq_(canary, ["l2"]) 2157 2158 @testing.requires.ad_hoc_engines 2159 def test_cant_listen_to_option_engine(self): 2160 from sqlalchemy.engine import base 2161 2162 def evt(*arg, **kw): 2163 pass 2164 2165 assert_raises_message( 2166 tsa.exc.InvalidRequestError, 2167 r"Can't assign an event directly to the " 2168 "<class 'sqlalchemy.engine.base.OptionEngine'> class", 2169 event.listen, 2170 base.OptionEngine, 2171 "before_cursor_execute", 2172 evt, 2173 ) 2174 2175 @testing.requires.ad_hoc_engines 2176 def test_dispose_event(self, testing_engine): 2177 canary = Mock() 2178 eng = testing_engine(testing.db.url) 2179 event.listen(eng, "engine_disposed", canary) 2180 2181 conn = eng.connect() 2182 conn.close() 2183 eng.dispose() 2184 2185 conn = eng.connect() 2186 conn.close() 2187 2188 eq_(canary.mock_calls, [call(eng)]) 2189 2190 eng.dispose() 2191 2192 eq_(canary.mock_calls, [call(eng), call(eng)]) 2193 2194 def test_retval_flag(self): 2195 canary = [] 2196 2197 def tracker(name): 2198 def go(conn, *args, **kw): 2199 canary.append(name) 2200 2201 return go 2202 2203 def execute( 2204 conn, clauseelement, multiparams, params, execution_options 2205 ): 2206 canary.append("execute") 2207 return clauseelement, multiparams, params 2208 2209 def cursor_execute( 2210 conn, cursor, statement, parameters, context, executemany 2211 ): 2212 canary.append("cursor_execute") 2213 return statement, parameters 2214 2215 engine = engines.testing_engine() 2216 2217 assert_raises( 2218 tsa.exc.ArgumentError, 2219 event.listen, 2220 engine, 2221 "begin", 2222 tracker("begin"), 2223 retval=True, 2224 ) 2225 2226 event.listen(engine, "before_execute", execute, retval=True) 2227 event.listen( 2228 engine, "before_cursor_execute", cursor_execute, retval=True 2229 ) 2230 with engine.connect() as conn: 2231 conn.execute(select(1)) 2232 eq_(canary, ["execute", "cursor_execute"]) 2233 2234 @testing.requires.legacy_engine 2235 def test_engine_connect(self): 2236 engine = engines.testing_engine() 2237 2238 tracker = Mock() 2239 event.listen(engine, "engine_connect", tracker) 2240 2241 c1 = engine.connect() 2242 c2 = c1._branch() 2243 c1.close() 2244 eq_(tracker.mock_calls, [call(c1, False), call(c2, True)]) 2245 2246 def test_execution_options(self): 2247 engine = engines.testing_engine() 2248 2249 engine_tracker = Mock() 2250 conn_tracker = Mock() 2251 2252 event.listen(engine, "set_engine_execution_options", engine_tracker) 2253 event.listen(engine, "set_connection_execution_options", conn_tracker) 2254 2255 e2 = engine.execution_options(e1="opt_e1") 2256 c1 = engine.connect() 2257 c2 = c1.execution_options(c1="opt_c1") 2258 c3 = e2.connect() 2259 c4 = c3.execution_options(c3="opt_c3") 2260 eq_(engine_tracker.mock_calls, [call(e2, {"e1": "opt_e1"})]) 2261 eq_( 2262 conn_tracker.mock_calls, 2263 [call(c2, {"c1": "opt_c1"}), call(c4, {"c3": "opt_c3"})], 2264 ) 2265 2266 @testing.requires.sequences 2267 @testing.provide_metadata 2268 def test_cursor_execute(self): 2269 canary = [] 2270 2271 def tracker(name): 2272 def go(conn, cursor, statement, parameters, context, executemany): 2273 canary.append((statement, context)) 2274 2275 return go 2276 2277 engine = engines.testing_engine() 2278 2279 t = Table( 2280 "t", 2281 self.metadata, 2282 Column( 2283 "x", 2284 Integer, 2285 Sequence("t_id_seq"), 2286 primary_key=True, 2287 ), 2288 implicit_returning=False, 2289 ) 2290 self.metadata.create_all(engine) 2291 2292 with engine.begin() as conn: 2293 event.listen( 2294 conn, "before_cursor_execute", tracker("cursor_execute") 2295 ) 2296 conn.execute(t.insert()) 2297 2298 # we see the sequence pre-executed in the first call 2299 assert "t_id_seq" in canary[0][0] 2300 assert "INSERT" in canary[1][0] 2301 # same context 2302 is_(canary[0][1], canary[1][1]) 2303 2304 def test_transactional(self): 2305 canary = [] 2306 2307 def tracker(name): 2308 def go(conn, *args, **kw): 2309 canary.append(name) 2310 2311 return go 2312 2313 engine = engines.testing_engine() 2314 event.listen(engine, "before_execute", tracker("execute")) 2315 event.listen( 2316 engine, "before_cursor_execute", tracker("cursor_execute") 2317 ) 2318 event.listen(engine, "begin", tracker("begin")) 2319 event.listen(engine, "commit", tracker("commit")) 2320 event.listen(engine, "rollback", tracker("rollback")) 2321 2322 with engine.connect() as conn: 2323 trans = conn.begin() 2324 conn.execute(select(1)) 2325 trans.rollback() 2326 trans = conn.begin() 2327 conn.execute(select(1)) 2328 trans.commit() 2329 2330 eq_( 2331 canary, 2332 [ 2333 "begin", 2334 "execute", 2335 "cursor_execute", 2336 "rollback", 2337 "begin", 2338 "execute", 2339 "cursor_execute", 2340 "commit", 2341 ], 2342 ) 2343 2344 def test_transactional_named(self): 2345 canary = [] 2346 2347 def tracker(name): 2348 def go(*args, **kw): 2349 canary.append((name, set(kw))) 2350 2351 return go 2352 2353 engine = engines.testing_engine() 2354 event.listen(engine, "before_execute", tracker("execute"), named=True) 2355 event.listen( 2356 engine, 2357 "before_cursor_execute", 2358 tracker("cursor_execute"), 2359 named=True, 2360 ) 2361 event.listen(engine, "begin", tracker("begin"), named=True) 2362 event.listen(engine, "commit", tracker("commit"), named=True) 2363 event.listen(engine, "rollback", tracker("rollback"), named=True) 2364 2365 with engine.connect() as conn: 2366 trans = conn.begin() 2367 conn.execute(select(1)) 2368 trans.rollback() 2369 trans = conn.begin() 2370 conn.execute(select(1)) 2371 trans.commit() 2372 2373 eq_( 2374 canary, 2375 [ 2376 ("begin", set(["conn"])), 2377 ( 2378 "execute", 2379 set( 2380 [ 2381 "conn", 2382 "clauseelement", 2383 "multiparams", 2384 "params", 2385 "execution_options", 2386 ] 2387 ), 2388 ), 2389 ( 2390 "cursor_execute", 2391 set( 2392 [ 2393 "conn", 2394 "cursor", 2395 "executemany", 2396 "statement", 2397 "parameters", 2398 "context", 2399 ] 2400 ), 2401 ), 2402 ("rollback", set(["conn"])), 2403 ("begin", set(["conn"])), 2404 ( 2405 "execute", 2406 set( 2407 [ 2408 "conn", 2409 "clauseelement", 2410 "multiparams", 2411 "params", 2412 "execution_options", 2413 ] 2414 ), 2415 ), 2416 ( 2417 "cursor_execute", 2418 set( 2419 [ 2420 "conn", 2421 "cursor", 2422 "executemany", 2423 "statement", 2424 "parameters", 2425 "context", 2426 ] 2427 ), 2428 ), 2429 ("commit", set(["conn"])), 2430 ], 2431 ) 2432 2433 @testing.requires.savepoints 2434 @testing.requires.two_phase_transactions 2435 def test_transactional_advanced(self): 2436 canary1 = [] 2437 2438 def tracker1(name): 2439 def go(*args, **kw): 2440 canary1.append(name) 2441 2442 return go 2443 2444 canary2 = [] 2445 2446 def tracker2(name): 2447 def go(*args, **kw): 2448 canary2.append(name) 2449 2450 return go 2451 2452 engine = engines.testing_engine() 2453 for name in [ 2454 "begin", 2455 "savepoint", 2456 "rollback_savepoint", 2457 "release_savepoint", 2458 "rollback", 2459 "begin_twophase", 2460 "prepare_twophase", 2461 "commit_twophase", 2462 ]: 2463 event.listen(engine, "%s" % name, tracker1(name)) 2464 2465 conn = engine.connect() 2466 for name in [ 2467 "begin", 2468 "savepoint", 2469 "rollback_savepoint", 2470 "release_savepoint", 2471 "rollback", 2472 "begin_twophase", 2473 "prepare_twophase", 2474 "commit_twophase", 2475 ]: 2476 event.listen(conn, "%s" % name, tracker2(name)) 2477 2478 trans = conn.begin() 2479 trans2 = conn.begin_nested() 2480 conn.execute(select(1)) 2481 trans2.rollback() 2482 trans2 = conn.begin_nested() 2483 conn.execute(select(1)) 2484 trans2.commit() 2485 trans.rollback() 2486 2487 trans = conn.begin_twophase() 2488 conn.execute(select(1)) 2489 trans.prepare() 2490 trans.commit() 2491 2492 eq_( 2493 canary1, 2494 [ 2495 "begin", 2496 "savepoint", 2497 "rollback_savepoint", 2498 "savepoint", 2499 "release_savepoint", 2500 "rollback", 2501 "begin_twophase", 2502 "prepare_twophase", 2503 "commit_twophase", 2504 ], 2505 ) 2506 eq_( 2507 canary2, 2508 [ 2509 "begin", 2510 "savepoint", 2511 "rollback_savepoint", 2512 "savepoint", 2513 "release_savepoint", 2514 "rollback", 2515 "begin_twophase", 2516 "prepare_twophase", 2517 "commit_twophase", 2518 ], 2519 ) 2520 2521 2522class FutureEngineEventsTest(fixtures.FutureEngineMixin, EngineEventsTest): 2523 def test_future_fixture(self, testing_engine): 2524 e1 = testing_engine() 2525 2526 assert e1._is_future 2527 with e1.connect() as conn: 2528 assert conn._is_future 2529 2530 def test_emit_sql_in_autobegin(self, testing_engine): 2531 e1 = testing_engine(config.db_url) 2532 2533 canary = Mock() 2534 2535 @event.listens_for(e1, "begin") 2536 def begin(connection): 2537 result = connection.execute(select(1)).scalar() 2538 canary.got_result(result) 2539 2540 with e1.connect() as conn: 2541 assert conn._is_future 2542 conn.execute(select(1)).scalar() 2543 2544 assert conn.in_transaction() 2545 2546 conn.commit() 2547 2548 assert not conn.in_transaction() 2549 2550 eq_(canary.mock_calls, [call.got_result(1)]) 2551 2552 2553class HandleErrorTest(fixtures.TestBase): 2554 __requires__ = ("ad_hoc_engines",) 2555 __backend__ = True 2556 2557 def teardown_test(self): 2558 Engine.dispatch._clear() 2559 Engine._has_events = False 2560 2561 def test_handle_error(self): 2562 engine = engines.testing_engine() 2563 canary = Mock(return_value=None) 2564 2565 event.listen(engine, "handle_error", canary) 2566 2567 with engine.connect() as conn: 2568 try: 2569 conn.exec_driver_sql("SELECT FOO FROM I_DONT_EXIST") 2570 assert False 2571 except tsa.exc.DBAPIError as e: 2572 ctx = canary.mock_calls[0][1][0] 2573 2574 eq_(ctx.original_exception, e.orig) 2575 is_(ctx.sqlalchemy_exception, e) 2576 eq_(ctx.statement, "SELECT FOO FROM I_DONT_EXIST") 2577 2578 def test_exception_event_reraise(self): 2579 engine = engines.testing_engine() 2580 2581 class MyException(Exception): 2582 pass 2583 2584 @event.listens_for(engine, "handle_error", retval=True) 2585 def err(context): 2586 stmt = context.statement 2587 exception = context.original_exception 2588 if "ERROR ONE" in str(stmt): 2589 return MyException("my exception") 2590 elif "ERROR TWO" in str(stmt): 2591 return exception 2592 else: 2593 return None 2594 2595 conn = engine.connect() 2596 # case 1: custom exception 2597 assert_raises_message( 2598 MyException, 2599 "my exception", 2600 conn.exec_driver_sql, 2601 "SELECT 'ERROR ONE' FROM I_DONT_EXIST", 2602 ) 2603 # case 2: return the DBAPI exception we're given; 2604 # no wrapping should occur 2605 assert_raises( 2606 conn.dialect.dbapi.Error, 2607 conn.exec_driver_sql, 2608 "SELECT 'ERROR TWO' FROM I_DONT_EXIST", 2609 ) 2610 # case 3: normal wrapping 2611 assert_raises( 2612 tsa.exc.DBAPIError, 2613 conn.exec_driver_sql, 2614 "SELECT 'ERROR THREE' FROM I_DONT_EXIST", 2615 ) 2616 2617 def test_exception_event_reraise_chaining(self): 2618 engine = engines.testing_engine() 2619 2620 class MyException1(Exception): 2621 pass 2622 2623 class MyException2(Exception): 2624 pass 2625 2626 class MyException3(Exception): 2627 pass 2628 2629 @event.listens_for(engine, "handle_error", retval=True) 2630 def err1(context): 2631 stmt = context.statement 2632 2633 if ( 2634 "ERROR ONE" in str(stmt) 2635 or "ERROR TWO" in str(stmt) 2636 or "ERROR THREE" in str(stmt) 2637 ): 2638 return MyException1("my exception") 2639 elif "ERROR FOUR" in str(stmt): 2640 raise MyException3("my exception short circuit") 2641 2642 @event.listens_for(engine, "handle_error", retval=True) 2643 def err2(context): 2644 stmt = context.statement 2645 if ( 2646 "ERROR ONE" in str(stmt) or "ERROR FOUR" in str(stmt) 2647 ) and isinstance(context.chained_exception, MyException1): 2648 raise MyException2("my exception chained") 2649 elif "ERROR TWO" in str(stmt): 2650 return context.chained_exception 2651 else: 2652 return None 2653 2654 conn = engine.connect() 2655 2656 with patch.object( 2657 engine.dialect.execution_ctx_cls, "handle_dbapi_exception" 2658 ) as patched: 2659 assert_raises_message( 2660 MyException2, 2661 "my exception chained", 2662 conn.exec_driver_sql, 2663 "SELECT 'ERROR ONE' FROM I_DONT_EXIST", 2664 ) 2665 eq_(patched.call_count, 1) 2666 2667 with patch.object( 2668 engine.dialect.execution_ctx_cls, "handle_dbapi_exception" 2669 ) as patched: 2670 assert_raises( 2671 MyException1, 2672 conn.exec_driver_sql, 2673 "SELECT 'ERROR TWO' FROM I_DONT_EXIST", 2674 ) 2675 eq_(patched.call_count, 1) 2676 2677 with patch.object( 2678 engine.dialect.execution_ctx_cls, "handle_dbapi_exception" 2679 ) as patched: 2680 # test that non None from err1 isn't cancelled out 2681 # by err2 2682 assert_raises( 2683 MyException1, 2684 conn.exec_driver_sql, 2685 "SELECT 'ERROR THREE' FROM I_DONT_EXIST", 2686 ) 2687 eq_(patched.call_count, 1) 2688 2689 with patch.object( 2690 engine.dialect.execution_ctx_cls, "handle_dbapi_exception" 2691 ) as patched: 2692 assert_raises( 2693 tsa.exc.DBAPIError, 2694 conn.exec_driver_sql, 2695 "SELECT 'ERROR FIVE' FROM I_DONT_EXIST", 2696 ) 2697 eq_(patched.call_count, 1) 2698 2699 with patch.object( 2700 engine.dialect.execution_ctx_cls, "handle_dbapi_exception" 2701 ) as patched: 2702 assert_raises_message( 2703 MyException3, 2704 "my exception short circuit", 2705 conn.exec_driver_sql, 2706 "SELECT 'ERROR FOUR' FROM I_DONT_EXIST", 2707 ) 2708 eq_(patched.call_count, 1) 2709 2710 def test_exception_autorollback_fails(self): 2711 engine = engines.testing_engine() 2712 conn = engine.connect() 2713 2714 def boom(connection): 2715 raise engine.dialect.dbapi.OperationalError("rollback failed") 2716 2717 with expect_warnings( 2718 r"An exception has occurred during handling of a previous " 2719 r"exception. The previous exception " 2720 r"is.*(?:i_dont_exist|does not exist)", 2721 py2konly=True, 2722 ): 2723 with patch.object(conn.dialect, "do_rollback", boom): 2724 assert_raises_message( 2725 tsa.exc.OperationalError, 2726 "rollback failed", 2727 conn.exec_driver_sql, 2728 "insert into i_dont_exist (x) values ('y')", 2729 ) 2730 2731 def test_exception_event_ad_hoc_context(self): 2732 """test that handle_error is called with a context in 2733 cases where _handle_dbapi_error() is normally called without 2734 any context. 2735 2736 """ 2737 2738 engine = engines.testing_engine() 2739 2740 listener = Mock(return_value=None) 2741 event.listen(engine, "handle_error", listener) 2742 2743 nope = SomeException("nope") 2744 2745 class MyType(TypeDecorator): 2746 impl = Integer 2747 cache_ok = True 2748 2749 def process_bind_param(self, value, dialect): 2750 raise nope 2751 2752 with engine.connect() as conn: 2753 assert_raises_message( 2754 tsa.exc.StatementError, 2755 r"\(.*.SomeException\) " r"nope\n\[SQL\: u?SELECT 1 ", 2756 conn.execute, 2757 select(1).where(column("foo") == literal("bar", MyType())), 2758 ) 2759 2760 ctx = listener.mock_calls[0][1][0] 2761 assert ctx.statement.startswith("SELECT 1 ") 2762 is_(ctx.is_disconnect, False) 2763 is_(ctx.original_exception, nope) 2764 2765 def test_exception_event_non_dbapi_error(self): 2766 """test that handle_error is called with a context in 2767 cases where DBAPI raises an exception that is not a DBAPI 2768 exception, e.g. internal errors or encoding problems. 2769 2770 """ 2771 engine = engines.testing_engine() 2772 2773 listener = Mock(return_value=None) 2774 event.listen(engine, "handle_error", listener) 2775 2776 nope = TypeError("I'm not a DBAPI error") 2777 with engine.connect() as c: 2778 c.connection.cursor = Mock( 2779 return_value=Mock(execute=Mock(side_effect=nope)) 2780 ) 2781 2782 assert_raises_message( 2783 TypeError, 2784 "I'm not a DBAPI error", 2785 c.exec_driver_sql, 2786 "select ", 2787 ) 2788 ctx = listener.mock_calls[0][1][0] 2789 eq_(ctx.statement, "select ") 2790 is_(ctx.is_disconnect, False) 2791 is_(ctx.original_exception, nope) 2792 2793 def test_exception_event_disable_handlers(self): 2794 engine = engines.testing_engine() 2795 2796 class MyException1(Exception): 2797 pass 2798 2799 @event.listens_for(engine, "handle_error") 2800 def err1(context): 2801 stmt = context.statement 2802 2803 if "ERROR_ONE" in str(stmt): 2804 raise MyException1("my exception short circuit") 2805 2806 with engine.connect() as conn: 2807 assert_raises( 2808 tsa.exc.DBAPIError, 2809 conn.execution_options( 2810 skip_user_error_events=True 2811 ).exec_driver_sql, 2812 "SELECT ERROR_ONE FROM I_DONT_EXIST", 2813 ) 2814 2815 assert_raises( 2816 MyException1, 2817 conn.execution_options( 2818 skip_user_error_events=False 2819 ).exec_driver_sql, 2820 "SELECT ERROR_ONE FROM I_DONT_EXIST", 2821 ) 2822 2823 def _test_alter_disconnect(self, orig_error, evt_value): 2824 engine = engines.testing_engine() 2825 2826 @event.listens_for(engine, "handle_error") 2827 def evt(ctx): 2828 ctx.is_disconnect = evt_value 2829 2830 with patch.object( 2831 engine.dialect, "is_disconnect", Mock(return_value=orig_error) 2832 ): 2833 2834 with engine.connect() as c: 2835 try: 2836 c.exec_driver_sql("SELECT x FROM nonexistent") 2837 assert False 2838 except tsa.exc.StatementError as st: 2839 eq_(st.connection_invalidated, evt_value) 2840 2841 def test_alter_disconnect_to_true(self): 2842 self._test_alter_disconnect(False, True) 2843 self._test_alter_disconnect(True, True) 2844 2845 def test_alter_disconnect_to_false(self): 2846 self._test_alter_disconnect(True, False) 2847 self._test_alter_disconnect(False, False) 2848 2849 @testing.requires.independent_connections 2850 def _test_alter_invalidate_pool_to_false(self, set_to_false): 2851 orig_error = True 2852 2853 engine = engines.testing_engine() 2854 2855 @event.listens_for(engine, "handle_error") 2856 def evt(ctx): 2857 if set_to_false: 2858 ctx.invalidate_pool_on_disconnect = False 2859 2860 c1, c2, c3 = ( 2861 engine.pool.connect(), 2862 engine.pool.connect(), 2863 engine.pool.connect(), 2864 ) 2865 crecs = [conn._connection_record for conn in (c1, c2, c3)] 2866 c1.close() 2867 c2.close() 2868 c3.close() 2869 2870 with patch.object( 2871 engine.dialect, "is_disconnect", Mock(return_value=orig_error) 2872 ): 2873 2874 with engine.connect() as c: 2875 target_crec = c.connection._connection_record 2876 try: 2877 c.exec_driver_sql("SELECT x FROM nonexistent") 2878 assert False 2879 except tsa.exc.StatementError as st: 2880 eq_(st.connection_invalidated, True) 2881 2882 for crec in crecs: 2883 if crec is target_crec or not set_to_false: 2884 is_not(crec.dbapi_connection, crec.get_connection()) 2885 else: 2886 is_(crec.dbapi_connection, crec.get_connection()) 2887 2888 def test_alter_invalidate_pool_to_false(self): 2889 self._test_alter_invalidate_pool_to_false(True) 2890 2891 def test_alter_invalidate_pool_stays_true(self): 2892 self._test_alter_invalidate_pool_to_false(False) 2893 2894 def test_handle_error_event_connect_isolation_level(self): 2895 engine = engines.testing_engine() 2896 2897 class MySpecialException(Exception): 2898 pass 2899 2900 @event.listens_for(engine, "handle_error") 2901 def handle_error(ctx): 2902 raise MySpecialException("failed operation") 2903 2904 ProgrammingError = engine.dialect.dbapi.ProgrammingError 2905 with engine.connect() as conn: 2906 with patch.object( 2907 conn.dialect, 2908 "get_isolation_level", 2909 Mock(side_effect=ProgrammingError("random error")), 2910 ): 2911 assert_raises(MySpecialException, conn.get_isolation_level) 2912 2913 @testing.only_on("sqlite+pysqlite") 2914 def test_cursor_close_resultset_failed_connectionless(self): 2915 engine = engines.testing_engine() 2916 2917 the_conn = [] 2918 the_cursor = [] 2919 2920 @event.listens_for(engine, "after_cursor_execute") 2921 def go( 2922 connection, cursor, statement, parameters, context, executemany 2923 ): 2924 the_cursor.append(cursor) 2925 the_conn.append(connection) 2926 2927 with mock.patch( 2928 "sqlalchemy.engine.cursor.BaseCursorResult.__init__", 2929 Mock(side_effect=tsa.exc.InvalidRequestError("duplicate col")), 2930 ): 2931 with engine.connect() as conn: 2932 assert_raises( 2933 tsa.exc.InvalidRequestError, 2934 conn.execute, 2935 text("select 1"), 2936 ) 2937 2938 # cursor is closed 2939 assert_raises_message( 2940 engine.dialect.dbapi.ProgrammingError, 2941 "Cannot operate on a closed cursor", 2942 the_cursor[0].execute, 2943 "select 1", 2944 ) 2945 2946 # connection is closed 2947 assert the_conn[0].closed 2948 2949 @testing.only_on("sqlite+pysqlite") 2950 def test_cursor_close_resultset_failed_explicit(self): 2951 engine = engines.testing_engine() 2952 2953 the_cursor = [] 2954 2955 @event.listens_for(engine, "after_cursor_execute") 2956 def go( 2957 connection, cursor, statement, parameters, context, executemany 2958 ): 2959 the_cursor.append(cursor) 2960 2961 conn = engine.connect() 2962 2963 with mock.patch( 2964 "sqlalchemy.engine.cursor.BaseCursorResult.__init__", 2965 Mock(side_effect=tsa.exc.InvalidRequestError("duplicate col")), 2966 ): 2967 assert_raises( 2968 tsa.exc.InvalidRequestError, 2969 conn.execute, 2970 text("select 1"), 2971 ) 2972 2973 # cursor is closed 2974 assert_raises_message( 2975 engine.dialect.dbapi.ProgrammingError, 2976 "Cannot operate on a closed cursor", 2977 the_cursor[0].execute, 2978 "select 1", 2979 ) 2980 2981 # connection not closed 2982 assert not conn.closed 2983 2984 conn.close() 2985 2986 2987class OnConnectTest(fixtures.TestBase): 2988 __requires__ = ("sqlite",) 2989 2990 def setup_test(self): 2991 e = create_engine("sqlite://") 2992 2993 connection = Mock(get_server_version_info=Mock(return_value="5.0")) 2994 2995 def connect(*args, **kwargs): 2996 return connection 2997 2998 dbapi = Mock( 2999 sqlite_version_info=(99, 9, 9), 3000 version_info=(99, 9, 9), 3001 sqlite_version="99.9.9", 3002 paramstyle="named", 3003 connect=Mock(side_effect=connect), 3004 ) 3005 3006 sqlite3 = e.dialect.dbapi 3007 dbapi.Error = (sqlite3.Error,) 3008 dbapi.ProgrammingError = sqlite3.ProgrammingError 3009 3010 self.dbapi = dbapi 3011 self.ProgrammingError = sqlite3.ProgrammingError 3012 3013 def test_wraps_connect_in_dbapi(self): 3014 dbapi = self.dbapi 3015 dbapi.connect = Mock(side_effect=self.ProgrammingError("random error")) 3016 try: 3017 create_engine("sqlite://", module=dbapi).connect() 3018 assert False 3019 except tsa.exc.DBAPIError as de: 3020 assert not de.connection_invalidated 3021 3022 def test_handle_error_event_connect(self): 3023 dbapi = self.dbapi 3024 dbapi.connect = Mock(side_effect=self.ProgrammingError("random error")) 3025 3026 class MySpecialException(Exception): 3027 pass 3028 3029 eng = create_engine("sqlite://", module=dbapi) 3030 3031 @event.listens_for(eng, "handle_error") 3032 def handle_error(ctx): 3033 assert ctx.engine is eng 3034 assert ctx.connection is None 3035 raise MySpecialException("failed operation") 3036 3037 assert_raises(MySpecialException, eng.connect) 3038 3039 def test_handle_error_event_revalidate(self): 3040 dbapi = self.dbapi 3041 3042 class MySpecialException(Exception): 3043 pass 3044 3045 eng = create_engine("sqlite://", module=dbapi, _initialize=False) 3046 3047 @event.listens_for(eng, "handle_error") 3048 def handle_error(ctx): 3049 assert ctx.engine is eng 3050 assert ctx.connection is conn 3051 assert isinstance( 3052 ctx.sqlalchemy_exception, tsa.exc.ProgrammingError 3053 ) 3054 raise MySpecialException("failed operation") 3055 3056 conn = eng.connect() 3057 conn.invalidate() 3058 3059 dbapi.connect = Mock(side_effect=self.ProgrammingError("random error")) 3060 3061 assert_raises(MySpecialException, getattr, conn, "connection") 3062 3063 def test_handle_error_event_implicit_revalidate(self): 3064 dbapi = self.dbapi 3065 3066 class MySpecialException(Exception): 3067 pass 3068 3069 eng = create_engine("sqlite://", module=dbapi, _initialize=False) 3070 3071 @event.listens_for(eng, "handle_error") 3072 def handle_error(ctx): 3073 assert ctx.engine is eng 3074 assert ctx.connection is conn 3075 assert isinstance( 3076 ctx.sqlalchemy_exception, tsa.exc.ProgrammingError 3077 ) 3078 raise MySpecialException("failed operation") 3079 3080 conn = eng.connect() 3081 conn.invalidate() 3082 3083 dbapi.connect = Mock(side_effect=self.ProgrammingError("random error")) 3084 3085 assert_raises(MySpecialException, conn.execute, select(1)) 3086 3087 def test_handle_error_custom_connect(self): 3088 dbapi = self.dbapi 3089 3090 class MySpecialException(Exception): 3091 pass 3092 3093 def custom_connect(): 3094 raise self.ProgrammingError("random error") 3095 3096 eng = create_engine("sqlite://", module=dbapi, creator=custom_connect) 3097 3098 @event.listens_for(eng, "handle_error") 3099 def handle_error(ctx): 3100 assert ctx.engine is eng 3101 assert ctx.connection is None 3102 raise MySpecialException("failed operation") 3103 3104 assert_raises(MySpecialException, eng.connect) 3105 3106 def test_handle_error_event_connect_invalidate_flag(self): 3107 dbapi = self.dbapi 3108 dbapi.connect = Mock( 3109 side_effect=self.ProgrammingError( 3110 "Cannot operate on a closed database." 3111 ) 3112 ) 3113 3114 class MySpecialException(Exception): 3115 pass 3116 3117 eng = create_engine("sqlite://", module=dbapi) 3118 3119 @event.listens_for(eng, "handle_error") 3120 def handle_error(ctx): 3121 assert ctx.is_disconnect 3122 ctx.is_disconnect = False 3123 3124 try: 3125 eng.connect() 3126 assert False 3127 except tsa.exc.DBAPIError as de: 3128 assert not de.connection_invalidated 3129 3130 def test_cant_connect_stay_invalidated(self): 3131 class MySpecialException(Exception): 3132 pass 3133 3134 eng = create_engine("sqlite://") 3135 3136 @event.listens_for(eng, "handle_error") 3137 def handle_error(ctx): 3138 assert ctx.is_disconnect 3139 3140 conn = eng.connect() 3141 3142 conn.invalidate() 3143 3144 eng.pool._creator = Mock( 3145 side_effect=self.ProgrammingError( 3146 "Cannot operate on a closed database." 3147 ) 3148 ) 3149 3150 try: 3151 conn.connection 3152 assert False 3153 except tsa.exc.DBAPIError: 3154 assert conn.invalidated 3155 3156 def test_dont_touch_non_dbapi_exception_on_connect(self): 3157 dbapi = self.dbapi 3158 dbapi.connect = Mock(side_effect=TypeError("I'm not a DBAPI error")) 3159 3160 e = create_engine("sqlite://", module=dbapi) 3161 e.dialect.is_disconnect = is_disconnect = Mock() 3162 assert_raises_message(TypeError, "I'm not a DBAPI error", e.connect) 3163 eq_(is_disconnect.call_count, 0) 3164 3165 def test_ensure_dialect_does_is_disconnect_no_conn(self): 3166 """test that is_disconnect() doesn't choke if no connection, 3167 cursor given.""" 3168 dialect = testing.db.dialect 3169 dbapi = dialect.dbapi 3170 assert not dialect.is_disconnect( 3171 dbapi.OperationalError("test"), None, None 3172 ) 3173 3174 def test_invalidate_on_connect(self): 3175 """test that is_disconnect() is called during connect. 3176 3177 interpretation of connection failures are not supported by 3178 every backend. 3179 3180 """ 3181 dbapi = self.dbapi 3182 dbapi.connect = Mock( 3183 side_effect=self.ProgrammingError( 3184 "Cannot operate on a closed database." 3185 ) 3186 ) 3187 e = create_engine("sqlite://", module=dbapi) 3188 try: 3189 e.connect() 3190 assert False 3191 except tsa.exc.DBAPIError as de: 3192 assert de.connection_invalidated 3193 3194 @testing.only_on("sqlite+pysqlite") 3195 def test_initialize_connect_calls(self): 3196 """test for :ticket:`5497`, on_connect not called twice""" 3197 3198 m1 = Mock() 3199 cls_ = testing.db.dialect.__class__ 3200 3201 class SomeDialect(cls_): 3202 def initialize(self, connection): 3203 super(SomeDialect, self).initialize(connection) 3204 m1.initialize(connection) 3205 3206 def on_connect(self): 3207 oc = super(SomeDialect, self).on_connect() 3208 3209 def my_on_connect(conn): 3210 if oc: 3211 oc(conn) 3212 m1.on_connect(conn) 3213 3214 return my_on_connect 3215 3216 u1 = Mock( 3217 username=None, 3218 password=None, 3219 host=None, 3220 port=None, 3221 query={}, 3222 database=None, 3223 _instantiate_plugins=lambda kw: (u1, [], kw), 3224 _get_entrypoint=Mock( 3225 return_value=Mock(get_dialect_cls=lambda u: SomeDialect) 3226 ), 3227 ) 3228 eng = create_engine(u1, poolclass=QueuePool) 3229 # make sure other dialects aren't getting pulled in here 3230 eq_(eng.name, "sqlite") 3231 c = eng.connect() 3232 dbapi_conn_one = c.connection.dbapi_connection 3233 c.close() 3234 3235 eq_( 3236 m1.mock_calls, 3237 [call.on_connect(dbapi_conn_one), call.initialize(mock.ANY)], 3238 ) 3239 3240 c = eng.connect() 3241 3242 eq_( 3243 m1.mock_calls, 3244 [call.on_connect(dbapi_conn_one), call.initialize(mock.ANY)], 3245 ) 3246 3247 c2 = eng.connect() 3248 dbapi_conn_two = c2.connection.dbapi_connection 3249 3250 is_not(dbapi_conn_one, dbapi_conn_two) 3251 3252 eq_( 3253 m1.mock_calls, 3254 [ 3255 call.on_connect(dbapi_conn_one), 3256 call.initialize(mock.ANY), 3257 call.on_connect(dbapi_conn_two), 3258 ], 3259 ) 3260 3261 c.close() 3262 c2.close() 3263 3264 @testing.only_on("sqlite+pysqlite") 3265 def test_initialize_connect_race(self): 3266 """test for :ticket:`6337` fixing the regression in :ticket:`5497`, 3267 dialect init is mutexed""" 3268 3269 m1 = [] 3270 cls_ = testing.db.dialect.__class__ 3271 3272 class SomeDialect(cls_): 3273 def initialize(self, connection): 3274 super(SomeDialect, self).initialize(connection) 3275 m1.append("initialize") 3276 3277 def on_connect(self): 3278 oc = super(SomeDialect, self).on_connect() 3279 3280 def my_on_connect(conn): 3281 if oc: 3282 oc(conn) 3283 m1.append("on_connect") 3284 3285 return my_on_connect 3286 3287 u1 = Mock( 3288 username=None, 3289 password=None, 3290 host=None, 3291 port=None, 3292 query={}, 3293 database=None, 3294 _instantiate_plugins=lambda kw: (u1, [], kw), 3295 _get_entrypoint=Mock( 3296 return_value=Mock(get_dialect_cls=lambda u: SomeDialect) 3297 ), 3298 ) 3299 3300 for j in range(5): 3301 m1[:] = [] 3302 eng = create_engine( 3303 u1, 3304 poolclass=NullPool, 3305 connect_args={"check_same_thread": False}, 3306 ) 3307 3308 def go(): 3309 c = eng.connect() 3310 c.execute(text("select 1")) 3311 c.close() 3312 3313 threads = [threading.Thread(target=go) for i in range(10)] 3314 for t in threads: 3315 t.start() 3316 for t in threads: 3317 t.join() 3318 3319 eq_(m1, ["on_connect", "initialize"] + ["on_connect"] * 9) 3320 3321 3322class DialectEventTest(fixtures.TestBase): 3323 @contextmanager 3324 def _run_test(self, retval): 3325 m1 = Mock() 3326 3327 m1.do_execute.return_value = retval 3328 m1.do_executemany.return_value = retval 3329 m1.do_execute_no_params.return_value = retval 3330 e = engines.testing_engine(options={"_initialize": False}) 3331 3332 event.listen(e, "do_execute", m1.do_execute) 3333 event.listen(e, "do_executemany", m1.do_executemany) 3334 event.listen(e, "do_execute_no_params", m1.do_execute_no_params) 3335 3336 e.dialect.do_execute = m1.real_do_execute 3337 e.dialect.do_executemany = m1.real_do_executemany 3338 e.dialect.do_execute_no_params = m1.real_do_execute_no_params 3339 3340 def mock_the_cursor(cursor, *arg): 3341 arg[-1].get_result_proxy = Mock(return_value=Mock(context=arg[-1])) 3342 return retval 3343 3344 m1.real_do_execute.side_effect = ( 3345 m1.do_execute.side_effect 3346 ) = mock_the_cursor 3347 m1.real_do_executemany.side_effect = ( 3348 m1.do_executemany.side_effect 3349 ) = mock_the_cursor 3350 m1.real_do_execute_no_params.side_effect = ( 3351 m1.do_execute_no_params.side_effect 3352 ) = mock_the_cursor 3353 3354 with e.begin() as conn: 3355 yield conn, m1 3356 3357 def _assert(self, retval, m1, m2, mock_calls): 3358 eq_(m1.mock_calls, mock_calls) 3359 if retval: 3360 eq_(m2.mock_calls, []) 3361 else: 3362 eq_(m2.mock_calls, mock_calls) 3363 3364 def _test_do_execute(self, retval): 3365 with self._run_test(retval) as (conn, m1): 3366 result = conn.exec_driver_sql( 3367 "insert into table foo", {"foo": "bar"} 3368 ) 3369 self._assert( 3370 retval, 3371 m1.do_execute, 3372 m1.real_do_execute, 3373 [ 3374 call( 3375 result.context.cursor, 3376 "insert into table foo", 3377 {"foo": "bar"}, 3378 result.context, 3379 ) 3380 ], 3381 ) 3382 3383 def _test_do_executemany(self, retval): 3384 with self._run_test(retval) as (conn, m1): 3385 result = conn.exec_driver_sql( 3386 "insert into table foo", [{"foo": "bar"}, {"foo": "bar"}] 3387 ) 3388 self._assert( 3389 retval, 3390 m1.do_executemany, 3391 m1.real_do_executemany, 3392 [ 3393 call( 3394 result.context.cursor, 3395 "insert into table foo", 3396 [{"foo": "bar"}, {"foo": "bar"}], 3397 result.context, 3398 ) 3399 ], 3400 ) 3401 3402 def _test_do_execute_no_params(self, retval): 3403 with self._run_test(retval) as (conn, m1): 3404 result = conn.execution_options( 3405 no_parameters=True 3406 ).exec_driver_sql("insert into table foo") 3407 self._assert( 3408 retval, 3409 m1.do_execute_no_params, 3410 m1.real_do_execute_no_params, 3411 [ 3412 call( 3413 result.context.cursor, 3414 "insert into table foo", 3415 result.context, 3416 ) 3417 ], 3418 ) 3419 3420 def _test_cursor_execute(self, retval): 3421 with self._run_test(retval) as (conn, m1): 3422 dialect = conn.dialect 3423 3424 stmt = "insert into table foo" 3425 params = {"foo": "bar"} 3426 ctx = dialect.execution_ctx_cls._init_statement( 3427 dialect, 3428 conn, 3429 conn.connection, 3430 {}, 3431 stmt, 3432 [params], 3433 ) 3434 3435 conn._cursor_execute(ctx.cursor, stmt, params, ctx) 3436 3437 self._assert( 3438 retval, 3439 m1.do_execute, 3440 m1.real_do_execute, 3441 [call(ctx.cursor, "insert into table foo", {"foo": "bar"}, ctx)], 3442 ) 3443 3444 def test_do_execute_w_replace(self): 3445 self._test_do_execute(True) 3446 3447 def test_do_execute_wo_replace(self): 3448 self._test_do_execute(False) 3449 3450 def test_do_executemany_w_replace(self): 3451 self._test_do_executemany(True) 3452 3453 def test_do_executemany_wo_replace(self): 3454 self._test_do_executemany(False) 3455 3456 def test_do_execute_no_params_w_replace(self): 3457 self._test_do_execute_no_params(True) 3458 3459 def test_do_execute_no_params_wo_replace(self): 3460 self._test_do_execute_no_params(False) 3461 3462 def test_cursor_execute_w_replace(self): 3463 self._test_cursor_execute(True) 3464 3465 def test_cursor_execute_wo_replace(self): 3466 self._test_cursor_execute(False) 3467 3468 def test_connect_replace_params(self): 3469 e = engines.testing_engine(options={"_initialize": False}) 3470 3471 @event.listens_for(e, "do_connect") 3472 def evt(dialect, conn_rec, cargs, cparams): 3473 cargs[:] = ["foo", "hoho"] 3474 cparams.clear() 3475 cparams["bar"] = "bat" 3476 conn_rec.info["boom"] = "bap" 3477 3478 m1 = Mock() 3479 e.dialect.connect = m1.real_connect 3480 3481 with e.connect() as conn: 3482 eq_(m1.mock_calls, [call.real_connect("foo", "hoho", bar="bat")]) 3483 eq_(conn.info["boom"], "bap") 3484 3485 def test_connect_do_connect(self): 3486 e = engines.testing_engine(options={"_initialize": False}) 3487 3488 m1 = Mock() 3489 3490 @event.listens_for(e, "do_connect") 3491 def evt1(dialect, conn_rec, cargs, cparams): 3492 cargs[:] = ["foo", "hoho"] 3493 cparams.clear() 3494 cparams["bar"] = "bat" 3495 conn_rec.info["boom"] = "one" 3496 3497 @event.listens_for(e, "do_connect") 3498 def evt2(dialect, conn_rec, cargs, cparams): 3499 conn_rec.info["bap"] = "two" 3500 return m1.our_connect(cargs, cparams) 3501 3502 with e.connect() as conn: 3503 # called with args 3504 eq_( 3505 m1.mock_calls, 3506 [call.our_connect(["foo", "hoho"], {"bar": "bat"})], 3507 ) 3508 3509 eq_(conn.info["boom"], "one") 3510 eq_(conn.info["bap"], "two") 3511 3512 # returned our mock connection 3513 is_(conn.connection.dbapi_connection, m1.our_connect()) 3514 3515 def test_connect_do_connect_info_there_after_recycle(self): 3516 # test that info is maintained after the do_connect() 3517 # event for a soft invalidation. 3518 3519 e = engines.testing_engine(options={"_initialize": False}) 3520 3521 @event.listens_for(e, "do_connect") 3522 def evt1(dialect, conn_rec, cargs, cparams): 3523 conn_rec.info["boom"] = "one" 3524 3525 conn = e.connect() 3526 eq_(conn.info["boom"], "one") 3527 3528 conn.connection.invalidate(soft=True) 3529 conn.close() 3530 conn = e.connect() 3531 eq_(conn.info["boom"], "one") 3532 3533 def test_connect_do_connect_info_there_after_invalidate(self): 3534 # test that info is maintained after the do_connect() 3535 # event for a hard invalidation. 3536 3537 e = engines.testing_engine(options={"_initialize": False}) 3538 3539 @event.listens_for(e, "do_connect") 3540 def evt1(dialect, conn_rec, cargs, cparams): 3541 assert not conn_rec.info 3542 conn_rec.info["boom"] = "one" 3543 3544 conn = e.connect() 3545 eq_(conn.info["boom"], "one") 3546 3547 conn.connection.invalidate() 3548 conn = e.connect() 3549 eq_(conn.info["boom"], "one") 3550 3551 3552class FutureExecuteTest(fixtures.FutureEngineMixin, fixtures.TablesTest): 3553 __backend__ = True 3554 3555 @classmethod 3556 def define_tables(cls, metadata): 3557 Table( 3558 "users", 3559 metadata, 3560 Column("user_id", INT, primary_key=True, autoincrement=False), 3561 Column("user_name", VARCHAR(20)), 3562 test_needs_acid=True, 3563 ) 3564 Table( 3565 "users_autoinc", 3566 metadata, 3567 Column( 3568 "user_id", INT, primary_key=True, test_needs_autoincrement=True 3569 ), 3570 Column("user_name", VARCHAR(20)), 3571 test_needs_acid=True, 3572 ) 3573 3574 def test_non_dict_mapping(self, connection): 3575 """ensure arbitrary Mapping works for execute()""" 3576 3577 class NotADict(collections_abc.Mapping): 3578 def __init__(self, _data): 3579 self._data = _data 3580 3581 def __iter__(self): 3582 return iter(self._data) 3583 3584 def __len__(self): 3585 return len(self._data) 3586 3587 def __getitem__(self, key): 3588 return self._data[key] 3589 3590 def keys(self): 3591 return self._data.keys() 3592 3593 nd = NotADict({"a": 10, "b": 15}) 3594 eq_(dict(nd), {"a": 10, "b": 15}) 3595 3596 result = connection.execute( 3597 select( 3598 bindparam("a", type_=Integer), bindparam("b", type_=Integer) 3599 ), 3600 nd, 3601 ) 3602 eq_(result.first(), (10, 15)) 3603 3604 def test_row_works_as_mapping(self, connection): 3605 """ensure the RowMapping object works as a parameter dictionary for 3606 execute.""" 3607 3608 result = connection.execute( 3609 select(literal(10).label("a"), literal(15).label("b")) 3610 ) 3611 row = result.first() 3612 eq_(row, (10, 15)) 3613 eq_(row._mapping, {"a": 10, "b": 15}) 3614 3615 result = connection.execute( 3616 select( 3617 bindparam("a", type_=Integer).label("a"), 3618 bindparam("b", type_=Integer).label("b"), 3619 ), 3620 row._mapping, 3621 ) 3622 row = result.first() 3623 eq_(row, (10, 15)) 3624 eq_(row._mapping, {"a": 10, "b": 15}) 3625 3626 @testing.combinations( 3627 ({}, {}, {}), 3628 ({"a": "b"}, {}, {"a": "b"}), 3629 ({"a": "b", "d": "e"}, {"a": "c"}, {"a": "c", "d": "e"}), 3630 argnames="conn_opts, exec_opts, expected", 3631 ) 3632 def test_execution_opts_per_invoke( 3633 self, connection, conn_opts, exec_opts, expected 3634 ): 3635 opts = [] 3636 3637 @event.listens_for(connection, "before_cursor_execute") 3638 def before_cursor_execute( 3639 conn, cursor, statement, parameters, context, executemany 3640 ): 3641 opts.append(context.execution_options) 3642 3643 if conn_opts: 3644 connection = connection.execution_options(**conn_opts) 3645 3646 if exec_opts: 3647 connection.execute(select(1), execution_options=exec_opts) 3648 else: 3649 connection.execute(select(1)) 3650 3651 eq_(opts, [expected]) 3652 3653 @testing.combinations( 3654 ({}, {}, {}, {}), 3655 ({}, {"a": "b"}, {}, {"a": "b"}), 3656 ({}, {"a": "b", "d": "e"}, {"a": "c"}, {"a": "c", "d": "e"}), 3657 ( 3658 {"q": "z", "p": "r"}, 3659 {"a": "b", "p": "x", "d": "e"}, 3660 {"a": "c"}, 3661 {"q": "z", "p": "x", "a": "c", "d": "e"}, 3662 ), 3663 argnames="stmt_opts, conn_opts, exec_opts, expected", 3664 ) 3665 def test_execution_opts_per_invoke_execute_events( 3666 self, connection, stmt_opts, conn_opts, exec_opts, expected 3667 ): 3668 opts = [] 3669 3670 @event.listens_for(connection, "before_execute") 3671 def before_execute( 3672 conn, clauseelement, multiparams, params, execution_options 3673 ): 3674 opts.append(("before", execution_options)) 3675 3676 @event.listens_for(connection, "after_execute") 3677 def after_execute( 3678 conn, 3679 clauseelement, 3680 multiparams, 3681 params, 3682 execution_options, 3683 result, 3684 ): 3685 opts.append(("after", execution_options)) 3686 3687 stmt = select(1) 3688 3689 if stmt_opts: 3690 stmt = stmt.execution_options(**stmt_opts) 3691 3692 if conn_opts: 3693 connection = connection.execution_options(**conn_opts) 3694 3695 if exec_opts: 3696 connection.execute(stmt, execution_options=exec_opts) 3697 else: 3698 connection.execute(stmt) 3699 3700 eq_(opts, [("before", expected), ("after", expected)]) 3701 3702 def test_no_branching(self, connection): 3703 with testing.expect_deprecated( 3704 r"The Connection.connect\(\) method is considered legacy" 3705 ): 3706 assert_raises_message( 3707 NotImplementedError, 3708 "sqlalchemy.future.Connection does not support " 3709 "'branching' of new connections.", 3710 connection.connect, 3711 ) 3712 3713 3714class SetInputSizesTest(fixtures.TablesTest): 3715 __backend__ = True 3716 3717 __requires__ = ("independent_connections",) 3718 3719 @classmethod 3720 def define_tables(cls, metadata): 3721 Table( 3722 "users", 3723 metadata, 3724 Column("user_id", INT, primary_key=True, autoincrement=False), 3725 Column("user_name", VARCHAR(20)), 3726 ) 3727 3728 @testing.fixture 3729 def input_sizes_fixture(self, testing_engine): 3730 canary = mock.Mock() 3731 3732 def do_set_input_sizes(cursor, list_of_tuples, context): 3733 if not engine.dialect.positional: 3734 # sort by "user_id", "user_name", or otherwise 3735 # param name for a non-positional dialect, so that we can 3736 # confirm the ordering. mostly a py2 thing probably can't 3737 # occur on py3.6+ since we are passing dictionaries with 3738 # "user_id", "user_name" 3739 list_of_tuples = sorted( 3740 list_of_tuples, key=lambda elem: elem[0] 3741 ) 3742 canary.do_set_input_sizes(cursor, list_of_tuples, context) 3743 3744 def pre_exec(self): 3745 self.translate_set_input_sizes = None 3746 self.include_set_input_sizes = None 3747 self.exclude_set_input_sizes = None 3748 3749 engine = testing_engine() 3750 engine.connect().close() 3751 3752 # the idea of this test is we fully replace the dialect 3753 # do_set_input_sizes with a mock, and we can then intercept 3754 # the setting passed to the dialect. the test table uses very 3755 # "safe" datatypes so that the DBAPI does not actually need 3756 # setinputsizes() called in order to work. 3757 3758 with mock.patch.object( 3759 engine.dialect, "use_setinputsizes", True 3760 ), mock.patch.object( 3761 engine.dialect, "do_set_input_sizes", do_set_input_sizes 3762 ), mock.patch.object( 3763 engine.dialect.execution_ctx_cls, "pre_exec", pre_exec 3764 ): 3765 yield engine, canary 3766 3767 def test_set_input_sizes_no_event(self, input_sizes_fixture): 3768 engine, canary = input_sizes_fixture 3769 3770 with engine.begin() as conn: 3771 conn.execute( 3772 self.tables.users.insert(), 3773 [ 3774 {"user_id": 1, "user_name": "n1"}, 3775 {"user_id": 2, "user_name": "n2"}, 3776 ], 3777 ) 3778 3779 eq_( 3780 canary.mock_calls, 3781 [ 3782 call.do_set_input_sizes( 3783 mock.ANY, 3784 [ 3785 ( 3786 "user_id", 3787 mock.ANY, 3788 testing.eq_type_affinity(Integer), 3789 ), 3790 ( 3791 "user_name", 3792 mock.ANY, 3793 testing.eq_type_affinity(String), 3794 ), 3795 ], 3796 mock.ANY, 3797 ) 3798 ], 3799 ) 3800 3801 def test_set_input_sizes_expanding_param(self, input_sizes_fixture): 3802 engine, canary = input_sizes_fixture 3803 3804 with engine.connect() as conn: 3805 conn.execute( 3806 select(self.tables.users).where( 3807 self.tables.users.c.user_name.in_(["x", "y", "z"]) 3808 ) 3809 ) 3810 3811 eq_( 3812 canary.mock_calls, 3813 [ 3814 call.do_set_input_sizes( 3815 mock.ANY, 3816 [ 3817 ( 3818 "user_name_1_1", 3819 mock.ANY, 3820 testing.eq_type_affinity(String), 3821 ), 3822 ( 3823 "user_name_1_2", 3824 mock.ANY, 3825 testing.eq_type_affinity(String), 3826 ), 3827 ( 3828 "user_name_1_3", 3829 mock.ANY, 3830 testing.eq_type_affinity(String), 3831 ), 3832 ], 3833 mock.ANY, 3834 ) 3835 ], 3836 ) 3837 3838 @testing.requires.tuple_in 3839 def test_set_input_sizes_expanding_tuple_param(self, input_sizes_fixture): 3840 engine, canary = input_sizes_fixture 3841 3842 from sqlalchemy import tuple_ 3843 3844 with engine.connect() as conn: 3845 conn.execute( 3846 select(self.tables.users).where( 3847 tuple_( 3848 self.tables.users.c.user_id, 3849 self.tables.users.c.user_name, 3850 ).in_([(1, "x"), (2, "y")]) 3851 ) 3852 ) 3853 3854 eq_( 3855 canary.mock_calls, 3856 [ 3857 call.do_set_input_sizes( 3858 mock.ANY, 3859 [ 3860 ( 3861 "param_1_1_1", 3862 mock.ANY, 3863 testing.eq_type_affinity(Integer), 3864 ), 3865 ( 3866 "param_1_1_2", 3867 mock.ANY, 3868 testing.eq_type_affinity(String), 3869 ), 3870 ( 3871 "param_1_2_1", 3872 mock.ANY, 3873 testing.eq_type_affinity(Integer), 3874 ), 3875 ( 3876 "param_1_2_2", 3877 mock.ANY, 3878 testing.eq_type_affinity(String), 3879 ), 3880 ], 3881 mock.ANY, 3882 ) 3883 ], 3884 ) 3885 3886 def test_set_input_sizes_event(self, input_sizes_fixture): 3887 engine, canary = input_sizes_fixture 3888 3889 SPECIAL_STRING = mock.Mock() 3890 3891 @event.listens_for(engine, "do_setinputsizes") 3892 def do_setinputsizes( 3893 inputsizes, cursor, statement, parameters, context 3894 ): 3895 for k in inputsizes: 3896 if k.type._type_affinity is String: 3897 inputsizes[k] = ( 3898 SPECIAL_STRING, 3899 None, 3900 0, 3901 ) 3902 3903 with engine.begin() as conn: 3904 conn.execute( 3905 self.tables.users.insert(), 3906 [ 3907 {"user_id": 1, "user_name": "n1"}, 3908 {"user_id": 2, "user_name": "n2"}, 3909 ], 3910 ) 3911 3912 eq_( 3913 canary.mock_calls, 3914 [ 3915 call.do_set_input_sizes( 3916 mock.ANY, 3917 [ 3918 ( 3919 "user_id", 3920 mock.ANY, 3921 testing.eq_type_affinity(Integer), 3922 ), 3923 ( 3924 "user_name", 3925 (SPECIAL_STRING, None, 0), 3926 testing.eq_type_affinity(String), 3927 ), 3928 ], 3929 mock.ANY, 3930 ) 3931 ], 3932 ) 3933 3934 3935class DialectDoesntSupportCachingTest(fixtures.TestBase): 3936 """test the opt-in caching flag added in :ticket:`6184`.""" 3937 3938 __only_on__ = "sqlite+pysqlite" 3939 3940 __requires__ = ("sqlite_memory",) 3941 3942 @testing.fixture() 3943 def sqlite_no_cache_dialect(self, testing_engine): 3944 from sqlalchemy.dialects.sqlite.pysqlite import SQLiteDialect_pysqlite 3945 from sqlalchemy.dialects.sqlite.base import SQLiteCompiler 3946 from sqlalchemy.sql import visitors 3947 3948 class MyCompiler(SQLiteCompiler): 3949 def translate_select_structure(self, select_stmt, **kwargs): 3950 select = select_stmt 3951 3952 if not getattr(select, "_mydialect_visit", None): 3953 select = visitors.cloned_traverse(select_stmt, {}, {}) 3954 if select._limit_clause is not None: 3955 # create a bindparam with a fixed name and hardcode 3956 # it to the given limit. this breaks caching. 3957 select._limit_clause = bindparam( 3958 "limit", value=select._limit, literal_execute=True 3959 ) 3960 3961 select._mydialect_visit = True 3962 3963 return select 3964 3965 class MyDialect(SQLiteDialect_pysqlite): 3966 statement_compiler = MyCompiler 3967 3968 from sqlalchemy.dialects import registry 3969 3970 def go(name): 3971 return MyDialect 3972 3973 with mock.patch.object(registry, "load", go): 3974 eng = testing_engine() 3975 yield eng 3976 3977 @testing.fixture 3978 def data_fixture(self, sqlite_no_cache_dialect): 3979 m = MetaData() 3980 t = Table("t1", m, Column("x", Integer)) 3981 with sqlite_no_cache_dialect.begin() as conn: 3982 t.create(conn) 3983 conn.execute(t.insert(), [{"x": 1}, {"x": 2}, {"x": 3}, {"x": 4}]) 3984 3985 return t 3986 3987 def test_no_cache(self, sqlite_no_cache_dialect, data_fixture): 3988 eng = sqlite_no_cache_dialect 3989 3990 def go(lim): 3991 with eng.connect() as conn: 3992 result = conn.execute( 3993 select(data_fixture).order_by(data_fixture.c.x).limit(lim) 3994 ) 3995 return result 3996 3997 r1 = go(2) 3998 r2 = go(3) 3999 4000 eq_(r1.all(), [(1,), (2,)]) 4001 eq_(r2.all(), [(1,), (2,), (3,)]) 4002 4003 def test_it_caches(self, sqlite_no_cache_dialect, data_fixture): 4004 eng = sqlite_no_cache_dialect 4005 eng.dialect.__class__.supports_statement_cache = True 4006 del eng.dialect.__dict__["_supports_statement_cache"] 4007 4008 def go(lim): 4009 with eng.connect() as conn: 4010 result = conn.execute( 4011 select(data_fixture).order_by(data_fixture.c.x).limit(lim) 4012 ) 4013 return result 4014 4015 r1 = go(2) 4016 r2 = go(3) 4017 4018 eq_(r1.all(), [(1,), (2,)]) 4019 4020 # wrong answer 4021 eq_( 4022 r2.all(), 4023 [ 4024 (1,), 4025 (2,), 4026 ], 4027 ) 4028