"""Tests for :class:`.AsyncSession` and the asyncio ORM helpers.

Covers construction, querying, transaction handling, cascades, ORM
behaviors under asyncio, events, proxying to the underlying sync
:class:`.Session`, and overriding the sync session class.
"""
from sqlalchemy import Column
from sqlalchemy import event
from sqlalchemy import exc
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import select
from sqlalchemy import Table
from sqlalchemy import testing
from sqlalchemy import update
from sqlalchemy.ext.asyncio import async_object_session
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio.base import ReversibleProxy
from sqlalchemy.orm import relationship
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from sqlalchemy.testing import async_test
from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import is_
from sqlalchemy.testing import is_true
from sqlalchemy.testing import mock
from .test_engine_py3k import AsyncFixture as _AsyncFixture
from ...orm import _fixtures


class AsyncFixture(_AsyncFixture, _fixtures.FixtureTest):
    """Base fixture for the asyncio ORM tests in this module.

    Combines the async fixture base imported from ``test_engine_py3k``
    with the ORM ``FixtureTest`` stock mapping (``User``, ``Address``, ...).
    """

    # only run where the "async_dialect" requirement is satisfied
    __requires__ = ("async_dialect",)

    @classmethod
    def setup_mappers(cls):
        cls._setup_stock_mapping()

    @testing.fixture
    def async_engine(self):
        """Return a new asyncio testing engine."""
        # NOTE(review): transfer_staticpool presumably lets a StaticPool
        # backend share its connection with the async engine -- confirm
        # against engines.testing_engine.
        return engines.testing_engine(asyncio=True, transfer_staticpool=True)

    @testing.fixture
    def async_session(self, async_engine):
        """Return an AsyncSession bound to the ``async_engine`` fixture."""
        return AsyncSession(async_engine)


class AsyncSessionTest(AsyncFixture):
    """Constructor-level behavior of :class:`.AsyncSession`."""

    def test_requires_async_engine(self, async_engine):
        """Binding to a sync Engine raises ArgumentError."""
        testing.assert_raises_message(
            exc.ArgumentError,
            "AsyncEngine expected, got Engine",
            AsyncSession,
            bind=async_engine.sync_engine,
        )

    def test_info(self, async_session):
        """``AsyncSession.info`` is shared with the sync Session's info."""
        async_session.info["foo"] = "bar"

        eq_(async_session.sync_session.info, {"foo": "bar"})

    def test_init(self, async_engine):
        """``bind`` and ``binds`` constructor arguments are exposed as-is."""
        ss = AsyncSession(bind=async_engine)
        is_(ss.bind, async_engine)

        binds = {Table: async_engine}
        ss = AsyncSession(binds=binds)
        is_(ss.binds, binds)


class AsyncSessionQueryTest(AsyncFixture):
    """Querying through AsyncSession against the fixture's inserted rows."""

    @async_test
    @testing.combinations(
        {}, dict(execution_options={"logging_token": "test"}), argnames="kw"
    )
    async def test_execute(self, async_session, kw):
        """execute() with selectinload, with and without execution options."""
        User = self.classes.User

        stmt = (
            select(User)
            .options(selectinload(User.addresses))
            .order_by(User.id)
        )

        result = await async_session.execute(stmt, **kw)
        eq_(result.scalars().all(), self.static.user_address_result)

    @async_test
    async def test_scalar(self, async_session):
        """scalar() returns the first column of the first row."""
        User = self.classes.User

        stmt = select(User.id).order_by(User.id).limit(1)

        result = await async_session.scalar(stmt)
        eq_(result, 7)

    @testing.combinations(
        ("scalars",), ("stream_scalars",), argnames="filter_"
    )
    @async_test
    async def test_scalars(self, async_session, filter_):
        """scalars() and stream_scalars() yield the same ORM objects."""
        User = self.classes.User

        stmt = (
            select(User)
            .options(selectinload(User.addresses))
            .order_by(User.id)
        )

        if filter_ == "scalars":
            result = (await async_session.scalars(stmt)).all()
        elif filter_ == "stream_scalars":
            result = await (await async_session.stream_scalars(stmt)).all()
        eq_(result, self.static.user_address_result)

    @async_test
    async def test_get(self, async_session):
        """get() uses the identity map and returns None for a missing PK."""
        User = self.classes.User

        u1 = await async_session.get(User, 7)

        eq_(u1.name, "jack")

        u2 = await async_session.get(User, 7)

        # the second get() returns the identical instance
        is_(u1, u2)

        u3 = await async_session.get(User, 12)
        is_(u3, None)

    @async_test
    async def test_get_loader_options(self, async_session):
        """Loader options passed to get() populate the relationship."""
        User = self.classes.User

        u = await async_session.get(
            User, 7, options=[selectinload(User.addresses)]
        )

        eq_(u.name, "jack")
        # check __dict__ directly so no lazy load is triggered
        eq_(len(u.__dict__["addresses"]), 1)

    @async_test
    @testing.requires.independent_cursors
    @testing.combinations(
        {}, dict(execution_options={"logging_token": "test"}), argnames="kw"
    )
    async def test_stream_partitions(self, async_session, kw):
        """stream() results can be consumed asynchronously in partitions."""
        User = self.classes.User

        stmt = (
            select(User)
            .options(selectinload(User.addresses))
            .order_by(User.id)
        )

        result = await async_session.stream(stmt, **kw)

        assert_result = []
        async for partition in result.scalars().partitions(3):
            assert_result.append(partition)

        eq_(
            assert_result,
            [
                self.static.user_address_result[0:3],
                self.static.user_address_result[3:],
            ],
        )


class AsyncSessionTransactionTest(AsyncFixture):
    """Transaction demarcation with AsyncSession.

    ``run_inserts = None`` disables the fixture data, so every test here
    starts with empty tables.
    """

    run_inserts = None

    @async_test
    async def test_interrupt_ctxmanager_connection(
        self, async_trans_ctx_manager_fixture, async_session
    ):
        """Run the shared transaction context-manager scenario."""
        fn = async_trans_ctx_manager_fixture

        await fn(async_session, trans_on_subject=True, execute_on_subject=True)

    @async_test
    async def test_sessionmaker_block_one(self, async_engine):
        """``async with session.begin()`` commits on block exit."""

        User = self.classes.User
        maker = sessionmaker(async_engine, class_=AsyncSession)

        session = maker()

        async with session.begin():
            u1 = User(name="u1")
            assert session.in_transaction()
            session.add(u1)

        assert not session.in_transaction()

        # a brand-new session sees the committed row
        async with maker() as session:
            result = await session.execute(
                select(User).where(User.name == "u1")
            )

            u1 = result.scalar_one()

            eq_(u1.name, "u1")

    @async_test
    async def test_sessionmaker_block_two(self, async_engine):
        """``async with maker.begin()`` yields a session in a transaction."""

        User = self.classes.User
        maker = sessionmaker(async_engine, class_=AsyncSession)

        async with maker.begin() as session:
            u1 = User(name="u1")
            assert session.in_transaction()
            session.add(u1)

        assert not session.in_transaction()

        async with maker() as session:
            result = await session.execute(
                select(User).where(User.name == "u1")
            )

            u1 = result.scalar_one()

            eq_(u1.name, "u1")

    @async_test
    async def test_trans(self, async_session, async_engine):
        """Changes become visible to other connections after begin() exits."""
        async with async_engine.connect() as outer_conn:

            User = self.classes.User

            async with async_session.begin():

                eq_(await outer_conn.scalar(select(func.count(User.id))), 0)

                u1 = User(name="u1")

                async_session.add(u1)

                result = await async_session.execute(select(User))
                eq_(result.scalar(), u1)

            # end outer_conn's current transaction so the next query runs
            # in a new one (presumably so snapshot-isolation backends see
            # the committed row)
            await outer_conn.rollback()
            eq_(await outer_conn.scalar(select(func.count(User.id))), 1)

    @async_test
    async def test_commit_as_you_go(self, async_session, async_engine):
        """Autobegin plus an explicit commit() persists the row."""
        async with async_engine.connect() as outer_conn:

            User = self.classes.User

            eq_(await outer_conn.scalar(select(func.count(User.id))), 0)

            u1 = User(name="u1")

            async_session.add(u1)

            result = await async_session.execute(select(User))
            eq_(result.scalar(), u1)

            await async_session.commit()

            await outer_conn.rollback()
            eq_(await outer_conn.scalar(select(func.count(User.id))), 1)

    @async_test
    async def test_trans_noctx(self, async_session, async_engine):
        """An awaited begin() can be committed without a context manager."""
        async with async_engine.connect() as outer_conn:

            User = self.classes.User

            trans = await async_session.begin()
            try:
                eq_(await outer_conn.scalar(select(func.count(User.id))), 0)

                u1 = User(name="u1")

                async_session.add(u1)

                result = await async_session.execute(select(User))
                eq_(result.scalar(), u1)
            finally:
                await trans.commit()

            await outer_conn.rollback()
            eq_(await outer_conn.scalar(select(func.count(User.id))), 1)

    @async_test
    async def test_delete(self, async_session):
        """An awaited delete() followed by flush() removes the row."""
        User = self.classes.User

        async with async_session.begin():
            u1 = User(name="u1")

            async_session.add(u1)

            await async_session.flush()

            conn = await async_session.connection()

            eq_(await conn.scalar(select(func.count(User.id))), 1)

            await async_session.delete(u1)

            await async_session.flush()

            eq_(await conn.scalar(select(func.count(User.id))), 0)

    @async_test
    async def test_flush(self, async_session):
        """Pending objects reach the database only after flush()."""
        User = self.classes.User

        async with async_session.begin():
            u1 = User(name="u1")

            async_session.add(u1)

            conn = await async_session.connection()

            # querying via the connection does not autoflush the session
            eq_(await conn.scalar(select(func.count(User.id))), 0)

            await async_session.flush()

            eq_(await conn.scalar(select(func.count(User.id))), 1)

    @async_test
    async def test_refresh(self, async_session):
        """refresh() reloads attributes changed behind the session's back."""
        User = self.classes.User

        async with async_session.begin():
            u1 = User(name="u1")

            async_session.add(u1)
            await async_session.flush()

            conn = await async_session.connection()

            # synchronize_session=None leaves u1 stale in the identity map
            await conn.execute(
                update(User)
                .values(name="u2")
                .execution_options(synchronize_session=None)
            )

            eq_(u1.name, "u1")

            await async_session.refresh(u1)

            eq_(u1.name, "u2")

            eq_(await conn.scalar(select(func.count(User.id))), 1)

    @async_test
    async def test_merge(self, async_session):
        """merge() returns the persistent instance with copied state."""
        User = self.classes.User

        async with async_session.begin():
            u1 = User(id=1, name="u1")

            async_session.add(u1)

        async with async_session.begin():
            new_u = User(id=1, name="new u1")

            new_u_merged = await async_session.merge(new_u)

            is_(new_u_merged, u1)
            eq_(u1.name, "new u1")

    @async_test
    async def test_merge_loader_options(self, async_session):
        """Loader options passed to merge() apply to the load it emits."""
        User = self.classes.User
        Address = self.classes.Address

        async with async_session.begin():
            u1 = User(id=1, name="u1", addresses=[Address(email_address="e1")])

            async_session.add(u1)

        # empty the identity map so merge() has to load from the database
        await async_session.close()

        async with async_session.begin():
            new_u1 = User(id=1, name="new u1")

            new_u_merged = await async_session.merge(
                new_u1, options=[selectinload(User.addresses)]
            )

            eq_(new_u_merged.name, "new u1")
            eq_(len(new_u_merged.__dict__["addresses"]), 1)

    @async_test
    async def test_join_to_external_transaction(self, async_engine):
        """A session bound to a Connection joins its outer transaction.

        session.commit() leaves the outer transaction open; rolling that
        back discards the session's work.
        """
        User = self.classes.User

        async with async_engine.connect() as conn:
            t1 = await conn.begin()

            async_session = AsyncSession(conn)

            aconn = await async_session.connection()

            eq_(aconn.get_transaction(), t1)

            eq_(aconn, conn)
            is_(aconn.sync_connection, conn.sync_connection)

            u1 = User(id=1, name="u1")

            async_session.add(u1)

            await async_session.commit()

            assert conn.in_transaction()
            await conn.rollback()

        async with AsyncSession(async_engine) as async_session:
            result = await async_session.execute(select(User))
            eq_(result.all(), [])

    @testing.requires.savepoints
    @async_test
    async def test_join_to_external_transaction_with_savepoints(
        self, async_engine
    ):
        """This is the full 'join to an external transaction' recipe
        implemented for async using savepoints.

        It's not particularly simple to understand, as we have to switch
        between the async and sync APIs, but it works and it's a start.

        """

        User = self.classes.User

        async with async_engine.connect() as conn:

            await conn.begin()

            await conn.begin_nested()

            async_session = AsyncSession(conn)

            @event.listens_for(
                async_session.sync_session, "after_transaction_end"
            )
            def end_savepoint(session, transaction):
                """Restart the SAVEPOINT whenever the session's transaction
                ends.

                Event listeners are attached to the sync Session and run
                synchronously, so this uses the blocking
                ``conn.sync_connection`` API instead of awaiting.

                """

                if conn.closed:
                    return

                if not conn.in_nested_transaction():
                    conn.sync_connection.begin_nested()

            aconn = await async_session.connection()
            is_(aconn.sync_connection, conn.sync_connection)

            u1 = User(id=1, name="u1")

            async_session.add(u1)

            await async_session.commit()

            result = (await async_session.execute(select(User))).all()
            eq_(len(result), 1)

            u2 = User(id=2, name="u2")
            async_session.add(u2)

            await async_session.flush()

            result = (await async_session.execute(select(User))).all()
            eq_(len(result), 2)

            # a rollback inside the session ultimately ends the savepoint
            await async_session.rollback()

            # but the previous thing we "committed" is still in the DB
            result = (await async_session.execute(select(User))).all()
            eq_(len(result), 1)

            assert conn.in_transaction()
            await conn.rollback()

        async with AsyncSession(async_engine) as async_session:
            result = await async_session.execute(select(User))
            eq_(result.all(), [])


class AsyncCascadesTest(AsyncFixture):
    """Relationship cascades when deleting through AsyncSession."""

    run_inserts = None

    @classmethod
    def setup_mappers(cls):
        User, Address = cls.classes("User", "Address")
        users, addresses = cls.tables("users", "addresses")

        cls.mapper(
            User,
            users,
            properties={
                "addresses": relationship(
                    Address, cascade="all, delete-orphan"
                )
            },
        )
        cls.mapper(
            Address,
            addresses,
        )

    @async_test
    async def test_delete_w_cascade(self, async_session):
        """Deleting a User cascades to its Address rows."""
        User = self.classes.User
        Address = self.classes.Address

        async with async_session.begin():
            u1 = User(id=1, name="u1", addresses=[Address(email_address="e1")])

            async_session.add(u1)

        async with async_session.begin():
            u1 = (await async_session.execute(select(User))).scalar_one()

            await async_session.delete(u1)

        eq_(
            (
                await async_session.execute(
                    select(func.count()).select_from(Address)
                )
            ).scalar(),
            0,
        )


class AsyncORMBehaviorsTest(AsyncFixture):
    """ORM attribute behaviors as they interact with asyncio."""

    @testing.fixture
    def one_to_one_fixture(self, registry, async_engine):
        """Return a coroutine function that maps and creates A/B tables.

        ``A.b`` is a scalar relationship to ``B``; the argument selects the
        relationship's ``_legacy_inactive_history_style`` flag.
        """

        async def go(legacy_inactive_history_style):
            @registry.mapped
            class A:
                __tablename__ = "a"

                id = Column(Integer, primary_key=True)
                b = relationship(
                    "B",
                    uselist=False,
                    _legacy_inactive_history_style=(
                        legacy_inactive_history_style
                    ),
                )

            @registry.mapped
            class B:
                __tablename__ = "b"
                id = Column(Integer, primary_key=True)
                a_id = Column(ForeignKey("a.id"))

            async with async_engine.begin() as conn:
                await conn.run_sync(registry.metadata.create_all)

            return A, B

        return go

    @testing.combinations(
        (
            "legacy_style",
            True,
        ),
        (
            "new_style",
            False,
        ),
        argnames="_legacy_inactive_history_style",
        id_="ia",
    )
    @async_test
    async def test_new_style_active_history(
        self, async_session, one_to_one_fixture, _legacy_inactive_history_style
    ):
        """Replacing ``a1.b`` after commit under each history style.

        The legacy style is expected to raise on assignment; the new style
        succeeds and the flush nulls out the old B's foreign key.
        """

        A, B = await one_to_one_fixture(_legacy_inactive_history_style)

        a1 = A()
        b1 = B()

        a1.b = b1
        async_session.add(a1)

        await async_session.commit()

        b2 = B()

        if _legacy_inactive_history_style:
            # aiomysql dialect having problems here, emitting weird
            # pytest warnings and we might need to just skip for aiomysql
            # here, which is also raising StatementError w/ MissingGreenlet
            # inside of it
            with testing.expect_raises(
                (exc.StatementError, exc.MissingGreenlet)
            ):
                a1.b = b2
        else:
            a1.b = b2

            await async_session.flush()

            await async_session.refresh(b1)

            # the previous B no longer references a1
            eq_(
                (
                    await async_session.execute(
                        select(func.count())
                        .where(B.id == b1.id)
                        .where(B.a_id == None)  # noqa: E711
                    )
                ).scalar(),
                1,
            )


class AsyncEventTest(AsyncFixture):
    """Session events all run in their normal synchronous context.

    We do not provide an asyncio event interface at this time; listeners
    must be applied to ``AsyncSession.sync_session``.

    """

    __backend__ = True

    @async_test
    async def test_no_async_listeners(self, async_session):
        """Listening on the AsyncSession itself raises NotImplementedError."""
        with testing.expect_raises_message(
            NotImplementedError,
            "asynchronous events are not implemented at this time. "
            "Apply synchronous listeners to the AsyncSession.sync_session.",
        ):
            event.listen(async_session, "before_flush", mock.Mock())

    @async_test
    async def test_sync_before_commit(self, async_session):
        """A listener on sync_session fires for an async commit."""
        canary = mock.Mock()

        event.listen(async_session.sync_session, "before_commit", canary)

        async with async_session.begin():
            pass

        eq_(
            canary.mock_calls,
            [mock.call(async_session.sync_session)],
        )


class AsyncProxyTest(AsyncFixture):
    """AsyncSession accessors for connections, transactions and ownership."""

    @async_test
    async def test_get_connection_engine_bound(self, async_session):
        """Repeated connection() calls return the same AsyncConnection."""
        c1 = await async_session.connection()

        c2 = await async_session.connection()

        is_(c1, c2)
        is_(c1.engine, c2.engine)

    @async_test
    async def test_get_connection_kws(self, async_session):
        """execution_options given to connection() reach the sync conn."""
        c1 = await async_session.connection(
            execution_options={"isolation_level": "AUTOCOMMIT"}
        )

        eq_(
            c1.sync_connection._execution_options,
            {"isolation_level": "AUTOCOMMIT"},
        )

    @async_test
    async def test_get_connection_connection_bound(self, async_engine):
        """A connection-bound session returns that very connection."""
        async with async_engine.begin() as conn:
            async_session = AsyncSession(conn)

            c1 = await async_session.connection()

            is_(c1, conn)
            is_(c1.engine, conn.engine)

    @async_test
    async def test_get_transaction(self, async_session):
        """get_transaction()/get_nested_transaction() track begin/commit."""

        is_(async_session.get_transaction(), None)
        is_(async_session.get_nested_transaction(), None)

        t1 = await async_session.begin()

        is_(async_session.get_transaction(), t1)
        is_(async_session.get_nested_transaction(), None)

        n1 = await async_session.begin_nested()

        is_(async_session.get_transaction(), t1)
        is_(async_session.get_nested_transaction(), n1)

        await n1.commit()

        is_(async_session.get_transaction(), t1)
        is_(async_session.get_nested_transaction(), None)

        await t1.commit()

        is_(async_session.get_transaction(), None)
        is_(async_session.get_nested_transaction(), None)

    @async_test
    async def test_async_object_session(self, async_engine):
        """async_object_session() returns the owning AsyncSession, if any."""
        User = self.classes.User

        s1 = AsyncSession(async_engine)

        s2 = AsyncSession(async_engine)

        u1 = await s1.get(User, 7)

        u2 = User(name="n1")

        s2.add(u2)

        # never added to any session
        u3 = User(name="n2")

        is_(async_object_session(u1), s1)
        is_(async_object_session(u2), s2)

        is_(async_object_session(u3), None)

        # closing a session detaches its objects
        await s2.close()
        is_(async_object_session(u2), None)

        # release the connection s1 checked out for get()
        await s1.close()

    @async_test
    async def test_async_object_session_custom(self, async_engine):
        """async_object_session() returns AsyncSession subclass instances."""
        User = self.classes.User

        class MyCustomAsync(AsyncSession):
            pass

        s1 = MyCustomAsync(async_engine)

        u1 = await s1.get(User, 7)

        assert isinstance(async_object_session(u1), MyCustomAsync)

        # release the connection s1 checked out for get()
        await s1.close()

    @testing.requires.predictable_gc
    @async_test
    async def test_async_object_session_del(self, async_engine):
        """Once the AsyncSession is dereferenced, lookups return None."""
        User = self.classes.User

        s1 = AsyncSession(async_engine)

        u1 = await s1.get(User, 7)

        is_(async_object_session(u1), s1)

        await s1.rollback()
        del s1
        is_(async_object_session(u1), None)

    @async_test
    async def test_inspect_session(self, async_engine):
        """``inspect(obj).async_session`` reports the owning AsyncSession."""
        User = self.classes.User

        s1 = AsyncSession(async_engine)

        s2 = AsyncSession(async_engine)

        u1 = await s1.get(User, 7)

        u2 = User(name="n1")

        s2.add(u2)

        u3 = User(name="n2")

        is_(inspect(u1).async_session, s1)
        is_(inspect(u2).async_session, s2)

        is_(inspect(u3).async_session, None)

        # release any connections held by the sessions
        await s1.close()
        await s2.close()

    def test_inspect_session_no_asyncio_used(self):
        """Objects from a plain sync Session have no async_session."""
        User = self.classes.User

        s1 = Session(testing.db)
        u1 = s1.get(User, 7)

        is_(inspect(u1).async_session, None)

        s1.close()

    def test_inspect_session_no_asyncio_imported(self):
        """async_session is None when the async provider hook is absent."""
        # NOTE(review): patching _async_provider to None presumably
        # simulates the asyncio extension never having been imported
        with mock.patch("sqlalchemy.orm.state._async_provider", None):

            User = self.classes.User

            s1 = Session(testing.db)
            u1 = s1.get(User, 7)

            is_(inspect(u1).async_session, None)

            s1.close()

    @testing.requires.predictable_gc
    def test_gc(self):
        """AsyncSession proxies are dropped from the registry when freed."""
        ReversibleProxy._proxy_objects.clear()

        eq_(len(ReversibleProxy._proxy_objects), 0)

        async_session = AsyncSession(async_engine)

        eq_(len(ReversibleProxy._proxy_objects), 1)

        del async_session

        eq_(len(ReversibleProxy._proxy_objects), 0)


class _MySession(Session):
    pass


class _MyAS(AsyncSession):
    sync_session_class = _MySession


class OverrideSyncSession(AsyncFixture):
    """The sync Session class behind an AsyncSession can be overridden.

    It may be set via the constructor, a sessionmaker, or a subclass
    attribute; a constructor argument takes precedence over the subclass.
    """

    def test_default(self, async_engine):
        ass = AsyncSession(async_engine)

        is_true(isinstance(ass.sync_session, Session))
        is_(ass.sync_session.__class__, Session)
        is_(ass.sync_session_class, Session)

    def test_init_class(self, async_engine):
        ass = AsyncSession(async_engine, sync_session_class=_MySession)

        is_true(isinstance(ass.sync_session, _MySession))
        is_(ass.sync_session_class, _MySession)

    def test_init_sessionmaker(self, async_engine):
        sm = sessionmaker(
            async_engine, class_=AsyncSession, sync_session_class=_MySession
        )
        ass = sm()

        is_true(isinstance(ass.sync_session, _MySession))
        is_(ass.sync_session_class, _MySession)

    def test_subclass(self, async_engine):
        ass = _MyAS(async_engine)

        is_true(isinstance(ass.sync_session, _MySession))
        is_(ass.sync_session_class, _MySession)

    def test_subclass_override(self, async_engine):
        ass = _MyAS(async_engine, sync_session_class=Session)

        is_true(not isinstance(ass.sync_session, _MySession))
        is_(ass.sync_session_class, Session)