1from sqlalchemy import Column
2from sqlalchemy import event
3from sqlalchemy import exc
4from sqlalchemy import ForeignKey
5from sqlalchemy import func
6from sqlalchemy import inspect
7from sqlalchemy import Integer
8from sqlalchemy import select
9from sqlalchemy import Table
10from sqlalchemy import testing
11from sqlalchemy import update
12from sqlalchemy.ext.asyncio import async_object_session
13from sqlalchemy.ext.asyncio import AsyncSession
14from sqlalchemy.ext.asyncio.base import ReversibleProxy
15from sqlalchemy.orm import relationship
16from sqlalchemy.orm import selectinload
17from sqlalchemy.orm import Session
18from sqlalchemy.orm import sessionmaker
19from sqlalchemy.testing import async_test
20from sqlalchemy.testing import engines
21from sqlalchemy.testing import eq_
22from sqlalchemy.testing import is_
23from sqlalchemy.testing import is_true
24from sqlalchemy.testing import mock
25from .test_engine_py3k import AsyncFixture as _AsyncFixture
26from ...orm import _fixtures
27
28
class AsyncFixture(_AsyncFixture, _fixtures.FixtureTest):
    """Shared base for the AsyncSession tests in this module.

    Combines the async engine fixtures from ``test_engine_py3k`` with the
    ORM ``FixtureTest`` schema, maps the stock fixture classes (e.g. User,
    Address) and exposes ``async_engine`` / ``async_session`` fixtures.
    """

    __requires__ = ("async_dialect",)

    @classmethod
    def setup_mappers(cls):
        """Map the stock fixture classes."""
        cls._setup_stock_mapping()

    @testing.fixture
    def async_engine(self):
        """An asyncio testing engine built by ``engines.testing_engine``."""
        engine = engines.testing_engine(
            asyncio=True, transfer_staticpool=True
        )
        return engine

    @testing.fixture
    def async_session(self, async_engine):
        """A new :class:`.AsyncSession` bound to ``async_engine``.

        The fixture itself does not close the session.
        """
        session = AsyncSession(async_engine)
        return session
43
44
class AsyncSessionTest(AsyncFixture):
    """Construction-time behavior of :class:`.AsyncSession`."""

    def test_requires_async_engine(self, async_engine):
        # handing the plain sync Engine to AsyncSession must be rejected
        sync_engine = async_engine.sync_engine
        testing.assert_raises_message(
            exc.ArgumentError,
            "AsyncEngine expected, got Engine",
            AsyncSession,
            bind=sync_engine,
        )

    def test_info(self, async_session):
        # writes to AsyncSession.info must be visible on the sync Session
        async_session.info["foo"] = "bar"

        expected = {"foo": "bar"}
        eq_(async_session.sync_session.info, expected)

    def test_init(self, async_engine):
        bound = AsyncSession(bind=async_engine)
        is_(bound.bind, async_engine)

        bind_map = {Table: async_engine}
        multi_bound = AsyncSession(binds=bind_map)
        is_(multi_bound.binds, bind_map)
66
67
class AsyncSessionQueryTest(AsyncFixture):
    """Query/execution entry points of :class:`.AsyncSession`, run against
    the fixture's default rows (e.g. user id 7, "jack")."""

    @async_test
    @testing.combinations(
        {}, dict(execution_options={"logging_token": "test"}), argnames="kw"
    )
    async def test_execute(self, async_session, kw):
        User = self.classes.User

        query = (
            select(User)
            .options(selectinload(User.addresses))
            .order_by(User.id)
        )

        users = (await async_session.execute(query, **kw)).scalars().all()
        eq_(users, self.static.user_address_result)

    @async_test
    async def test_scalar(self, async_session):
        User = self.classes.User

        first_id = await async_session.scalar(
            select(User.id).order_by(User.id).limit(1)
        )
        eq_(first_id, 7)

    @testing.combinations(
        ("scalars",), ("stream_scalars",), argnames="filter_"
    )
    @async_test
    async def test_scalars(self, async_session, filter_):
        User = self.classes.User

        query = (
            select(User)
            .options(selectinload(User.addresses))
            .order_by(User.id)
        )

        if filter_ == "scalars":
            scalar_result = await async_session.scalars(query)
            users = scalar_result.all()
        elif filter_ == "stream_scalars":
            streamed_result = await async_session.stream_scalars(query)
            users = await streamed_result.all()
        eq_(users, self.static.user_address_result)

    @async_test
    async def test_get(self, async_session):
        User = self.classes.User

        first = await async_session.get(User, 7)
        eq_(first.name, "jack")

        # a second get() for the same key hands back the very same object
        again = await async_session.get(User, 7)
        is_(first, again)

        # a primary key with no row yields None
        missing = await async_session.get(User, 12)
        is_(missing, None)

    @async_test
    async def test_get_loader_options(self, async_session):
        User = self.classes.User

        loaded = await async_session.get(
            User, 7, options=[selectinload(User.addresses)]
        )

        eq_(loaded.name, "jack")
        # the collection must already be present in __dict__ (eager-loaded)
        eq_(len(loaded.__dict__["addresses"]), 1)

    @async_test
    @testing.requires.independent_cursors
    @testing.combinations(
        {}, dict(execution_options={"logging_token": "test"}), argnames="kw"
    )
    async def test_stream_partitions(self, async_session, kw):
        User = self.classes.User

        query = (
            select(User)
            .options(selectinload(User.addresses))
            .order_by(User.id)
        )

        streamed = await async_session.stream(query, **kw)

        partitions = [
            partition
            async for partition in streamed.scalars().partitions(3)
        ]

        expected = self.static.user_address_result
        eq_(partitions, [expected[0:3], expected[3:]])
166
167
class AsyncSessionTransactionTest(AsyncFixture):
    """Transaction control through :class:`.AsyncSession`.

    Covers begin/commit blocks, sessionmaker usage, commit-as-you-go,
    flush/refresh/merge/delete inside transactions, and joining a
    transaction already begun on an external connection.
    """

    # no fixture rows; several tests below assert a starting count of 0
    run_inserts = None

    @async_test
    async def test_interrupt_ctxmanager_connection(
        self, async_trans_ctx_manager_fixture, async_session
    ):
        # delegates to a shared fixture (presumably provided by the imported
        # engine-test AsyncFixture base), using the AsyncSession as both the
        # transaction subject and the execution subject
        fn = async_trans_ctx_manager_fixture

        await fn(async_session, trans_on_subject=True, execute_on_subject=True)

    @async_test
    async def test_sessionmaker_block_one(self, async_engine):

        User = self.classes.User
        maker = sessionmaker(async_engine, class_=AsyncSession)

        session = maker()

        # the begin() block commits on exit; a fresh session sees the row
        async with session.begin():
            u1 = User(name="u1")
            assert session.in_transaction()
            session.add(u1)

        assert not session.in_transaction()

        async with maker() as session:
            result = await session.execute(
                select(User).where(User.name == "u1")
            )

            u1 = result.scalar_one()

            eq_(u1.name, "u1")

    @async_test
    async def test_sessionmaker_block_two(self, async_engine):

        User = self.classes.User
        maker = sessionmaker(async_engine, class_=AsyncSession)

        # same as block_one, but using sessionmaker.begin() which creates
        # the session and begins its transaction in one step
        async with maker.begin() as session:
            u1 = User(name="u1")
            assert session.in_transaction()
            session.add(u1)

        assert not session.in_transaction()

        async with maker() as session:
            result = await session.execute(
                select(User).where(User.name == "u1")
            )

            u1 = result.scalar_one()

            eq_(u1.name, "u1")

    @async_test
    async def test_trans(self, async_session, async_engine):
        # outer_conn is an independent connection used only to observe what
        # the session has actually committed
        async with async_engine.connect() as outer_conn:

            User = self.classes.User

            async with async_session.begin():

                eq_(await outer_conn.scalar(select(func.count(User.id))), 0)

                u1 = User(name="u1")

                async_session.add(u1)

                result = await async_session.execute(select(User))
                eq_(result.scalar(), u1)

            # end outer_conn's own transaction so the next count runs in a
            # new one and sees the row committed by the begin() block
            await outer_conn.rollback()
            eq_(await outer_conn.scalar(select(func.count(User.id))), 1)

    @async_test
    async def test_commit_as_you_go(self, async_session, async_engine):
        # like test_trans, but without an explicit begin(): the session
        # begins implicitly and is committed with an explicit commit()
        async with async_engine.connect() as outer_conn:

            User = self.classes.User

            eq_(await outer_conn.scalar(select(func.count(User.id))), 0)

            u1 = User(name="u1")

            async_session.add(u1)

            result = await async_session.execute(select(User))
            eq_(result.scalar(), u1)

            await async_session.commit()

            await outer_conn.rollback()
            eq_(await outer_conn.scalar(select(func.count(User.id))), 1)

    @async_test
    async def test_trans_noctx(self, async_session, async_engine):
        # like test_trans, but awaiting begin() directly and committing the
        # returned transaction object instead of using a context manager
        async with async_engine.connect() as outer_conn:

            User = self.classes.User

            trans = await async_session.begin()
            try:
                eq_(await outer_conn.scalar(select(func.count(User.id))), 0)

                u1 = User(name="u1")

                async_session.add(u1)

                result = await async_session.execute(select(User))
                eq_(result.scalar(), u1)
            finally:
                await trans.commit()

            await outer_conn.rollback()
            eq_(await outer_conn.scalar(select(func.count(User.id))), 1)

    @async_test
    async def test_delete(self, async_session):
        User = self.classes.User

        async with async_session.begin():
            u1 = User(name="u1")

            async_session.add(u1)

            await async_session.flush()

            # counts go through the session's own connection, so they see
            # flushed but uncommitted changes
            conn = await async_session.connection()

            eq_(await conn.scalar(select(func.count(User.id))), 1)

            await async_session.delete(u1)

            await async_session.flush()

            eq_(await conn.scalar(select(func.count(User.id))), 0)

    @async_test
    async def test_flush(self, async_session):
        User = self.classes.User

        async with async_session.begin():
            u1 = User(name="u1")

            async_session.add(u1)

            conn = await async_session.connection()

            # the pending object is not in the table until flush() runs
            eq_(await conn.scalar(select(func.count(User.id))), 0)

            await async_session.flush()

            eq_(await conn.scalar(select(func.count(User.id))), 1)

    @async_test
    async def test_refresh(self, async_session):
        User = self.classes.User

        async with async_session.begin():
            u1 = User(name="u1")

            async_session.add(u1)
            await async_session.flush()

            conn = await async_session.connection()

            # change the row behind the ORM's back, directly on the
            # connection, so the in-memory u1 is stale
            await conn.execute(
                update(User)
                .values(name="u2")
                .execution_options(synchronize_session=None)
            )

            eq_(u1.name, "u1")

            # refresh() reloads the attributes from the database
            await async_session.refresh(u1)

            eq_(u1.name, "u2")

            eq_(await conn.scalar(select(func.count(User.id))), 1)

    @async_test
    async def test_merge(self, async_session):
        User = self.classes.User

        async with async_session.begin():
            u1 = User(id=1, name="u1")

            async_session.add(u1)

        async with async_session.begin():
            new_u = User(id=1, name="new u1")

            # merging a transient with the same primary key must return the
            # object already held by the session, updated in place
            new_u_merged = await async_session.merge(new_u)

            is_(new_u_merged, u1)
            eq_(u1.name, "new u1")

    @async_test
    async def test_merge_loader_options(self, async_session):
        User = self.classes.User
        Address = self.classes.Address

        async with async_session.begin():
            u1 = User(id=1, name="u1", addresses=[Address(email_address="e1")])

            async_session.add(u1)

        # close so that merge() below has to load the row from the database
        await async_session.close()

        async with async_session.begin():
            new_u1 = User(id=1, name="new u1")

            new_u_merged = await async_session.merge(
                new_u1, options=[selectinload(User.addresses)]
            )

            eq_(new_u_merged.name, "new u1")
            # the loader option must have populated the collection eagerly
            eq_(len(new_u_merged.__dict__["addresses"]), 1)

    @async_test
    async def test_join_to_external_transaction(self, async_engine):
        User = self.classes.User

        async with async_engine.connect() as conn:
            t1 = await conn.begin()

            # a session bound to a Connection that already has a transaction
            # joins that transaction rather than starting its own
            async_session = AsyncSession(conn)

            aconn = await async_session.connection()

            eq_(aconn.get_transaction(), t1)

            eq_(aconn, conn)
            is_(aconn.sync_connection, conn.sync_connection)

            u1 = User(id=1, name="u1")

            async_session.add(u1)

            await async_session.commit()

            # the session's commit() did not end the outer transaction...
            assert conn.in_transaction()
            # ...so rolling it back discards u1
            await conn.rollback()

        async with AsyncSession(async_engine) as async_session:
            result = await async_session.execute(select(User))
            eq_(result.all(), [])

    @testing.requires.savepoints
    @async_test
    async def test_join_to_external_transaction_with_savepoints(
        self, async_engine
    ):
        """This is the full 'join to an external transaction' recipe
        implemented for async using savepoints.

        It's not particularly simple to understand as we have to switch between
        async / sync APIs but it works and it's a start.

        """

        User = self.classes.User

        async with async_engine.connect() as conn:

            # outer transaction, plus a SAVEPOINT the session will run in
            await conn.begin()

            await conn.begin_nested()

            async_session = AsyncSession(conn)

            @event.listens_for(
                async_session.sync_session, "after_transaction_end"
            )
            def end_savepoint(session, transaction):
                """Re-open a SAVEPOINT on the outer connection whenever one
                of the session's transactions ends.

                Event listeners run synchronously, so this uses
                ``conn.sync_connection`` rather than the async API.

                """

                if conn.closed:
                    return

                if not conn.in_nested_transaction():
                    conn.sync_connection.begin_nested()

            aconn = await async_session.connection()
            is_(aconn.sync_connection, conn.sync_connection)

            u1 = User(id=1, name="u1")

            async_session.add(u1)

            await async_session.commit()

            result = (await async_session.execute(select(User))).all()
            eq_(len(result), 1)

            u2 = User(id=2, name="u2")
            async_session.add(u2)

            await async_session.flush()

            result = (await async_session.execute(select(User))).all()
            eq_(len(result), 2)

            # a rollback inside the session ultimately ends the savepoint
            await async_session.rollback()

            # but the previous thing we "committed" is still in the DB
            result = (await async_session.execute(select(User))).all()
            eq_(len(result), 1)

            # the outer transaction is still open; rolling it back discards
            # everything the session did
            assert conn.in_transaction()
            await conn.rollback()

        async with AsyncSession(async_engine) as async_session:
            result = await async_session.execute(select(User))
            eq_(result.all(), [])
490
491
class AsyncCascadesTest(AsyncFixture):
    """Relationship cascades driven through :class:`.AsyncSession`."""

    run_inserts = None

    @classmethod
    def setup_mappers(cls):
        # User.addresses uses "all, delete-orphan" so deleting a User
        # cascades to its Address rows
        User, Address = cls.classes("User", "Address")
        users, addresses = cls.tables("users", "addresses")

        user_props = {
            "addresses": relationship(Address, cascade="all, delete-orphan")
        }
        cls.mapper(User, users, properties=user_props)
        cls.mapper(Address, addresses)

    @async_test
    async def test_delete_w_cascade(self, async_session):
        User = self.classes.User
        Address = self.classes.Address

        async with async_session.begin():
            user = User(id=1, name="u1", addresses=[Address(email_address="e1")])
            async_session.add(user)

        async with async_session.begin():
            user = (await async_session.execute(select(User))).scalar_one()
            await async_session.delete(user)

        # deleting the parent must have removed its Address row as well
        count_stmt = select(func.count()).select_from(Address)
        remaining = (await async_session.execute(count_stmt)).scalar()
        eq_(remaining, 0)
537
538
class AsyncORMBehaviorsTest(AsyncFixture):
    """ORM attribute behaviors exercised under :class:`.AsyncSession`."""

    @testing.fixture
    def one_to_one_fixture(self, registry, async_engine):
        """Return an async factory mapping a one-to-one ``A.b`` -> ``B``.

        The factory accepts ``legacy_inactive_history_style``, forwards it to
        the relationship's ``_legacy_inactive_history_style`` option, creates
        the tables on ``async_engine`` and returns the ``(A, B)`` classes.
        """

        async def go(legacy_inactive_history_style):
            @registry.mapped
            class A:
                __tablename__ = "a"

                id = Column(Integer, primary_key=True)
                b = relationship(
                    "B",
                    uselist=False,
                    _legacy_inactive_history_style=(
                        legacy_inactive_history_style
                    ),
                )

            @registry.mapped
            class B:
                __tablename__ = "b"
                id = Column(Integer, primary_key=True)
                a_id = Column(ForeignKey("a.id"))

            async with async_engine.begin() as conn:
                await conn.run_sync(registry.metadata.create_all)

            return A, B

        return go

    @testing.combinations(
        (
            "legacy_style",
            True,
        ),
        (
            "new_style",
            False,
        ),
        argnames="_legacy_inactive_history_style",
        id_="ia",
    )
    @async_test
    async def test_new_style_active_history(
        self, async_session, one_to_one_fixture, _legacy_inactive_history_style
    ):
        """Replace a committed one-to-one target with a new object.

        In the legacy style the assignment ``a1.b = b2`` is expected to raise
        (StatementError or MissingGreenlet) -- presumably because it needs
        to load the old value outside the greenlet. In the new style the
        assignment succeeds and, after flush, the old ``B`` row has its
        foreign key set to NULL.
        """

        A, B = await one_to_one_fixture(_legacy_inactive_history_style)

        a1 = A()
        b1 = B()

        a1.b = b1
        async_session.add(a1)

        await async_session.commit()

        b2 = B()

        if _legacy_inactive_history_style:
            # aiomysql dialect having problems here, emitting weird
            # pytest warnings and we might need to just skip for aiomysql
            # here, which is also raising StatementError w/ MissingGreenlet
            # inside of it
            with testing.expect_raises(
                (exc.StatementError, exc.MissingGreenlet)
            ):
                a1.b = b2
        else:
            a1.b = b2

            await async_session.flush()

            await async_session.refresh(b1)

            # b1 must now be de-associated from a1 (a_id IS NULL)
            eq_(
                (
                    await async_session.execute(
                        select(func.count())
                        .select_from(B)
                        .where(B.id == b1.id)
                        .where(B.a_id.is_(None))
                    )
                ).scalar(),
                1,
            )
624
625
class AsyncEventTest(AsyncFixture):
    """Session events only exist in their synchronous form.

    No asyncio event interface is provided at this time; listeners are
    attached to ``AsyncSession.sync_session`` and run synchronously.
    """

    __backend__ = True

    @async_test
    async def test_no_async_listeners(self, async_session):
        listener = mock.Mock()
        with testing.expect_raises_message(
            NotImplementedError,
            "asynchronous events are not implemented at this time.  "
            "Apply synchronous listeners to the AsyncSession.sync_session.",
        ):
            event.listen(async_session, "before_flush", listener)

    @async_test
    async def test_sync_before_commit(self, async_session):
        canary = mock.Mock()
        sync_session = async_session.sync_session

        event.listen(sync_session, "before_commit", canary)

        # even an empty begin() block commits, firing the listener once
        async with async_session.begin():
            pass

        eq_(canary.mock_calls, [mock.call(sync_session)])
657
658
class AsyncProxyTest(AsyncFixture):
    """Proxying between AsyncSession and its sync Session / connections,
    plus lookup of the owning AsyncSession from ORM objects."""

    @async_test
    async def test_get_connection_engine_bound(self, async_session):
        c1 = await async_session.connection()

        c2 = await async_session.connection()

        is_(c1, c2)
        is_(c1.engine, c2.engine)

    @async_test
    async def test_get_connection_kws(self, async_session):
        c1 = await async_session.connection(
            execution_options={"isolation_level": "AUTOCOMMIT"}
        )

        eq_(
            c1.sync_connection._execution_options,
            {"isolation_level": "AUTOCOMMIT"},
        )

    @async_test
    async def test_get_connection_connection_bound(self, async_engine):
        async with async_engine.begin() as conn:
            async_session = AsyncSession(conn)

            c1 = await async_session.connection()

            is_(c1, conn)
            is_(c1.engine, conn.engine)

    @async_test
    async def test_get_transaction(self, async_session):

        is_(async_session.get_transaction(), None)
        is_(async_session.get_nested_transaction(), None)

        t1 = await async_session.begin()

        is_(async_session.get_transaction(), t1)
        is_(async_session.get_nested_transaction(), None)

        n1 = await async_session.begin_nested()

        is_(async_session.get_transaction(), t1)
        is_(async_session.get_nested_transaction(), n1)

        await n1.commit()

        is_(async_session.get_transaction(), t1)
        is_(async_session.get_nested_transaction(), None)

        await t1.commit()

        is_(async_session.get_transaction(), None)
        is_(async_session.get_nested_transaction(), None)

    @async_test
    async def test_async_object_session(self, async_engine):
        User = self.classes.User

        s1 = AsyncSession(async_engine)

        s2 = AsyncSession(async_engine)

        u1 = await s1.get(User, 7)

        u2 = User(name="n1")

        s2.add(u2)

        u3 = User(name="n2")

        is_(async_object_session(u1), s1)
        is_(async_object_session(u2), s2)

        is_(async_object_session(u3), None)

        await s2.close()
        is_(async_object_session(u2), None)

        # s1.get() checked out a connection; release it
        await s1.close()

    @async_test
    async def test_async_object_session_custom(self, async_engine):
        User = self.classes.User

        class MyCustomAsync(AsyncSession):
            pass

        s1 = MyCustomAsync(async_engine)

        u1 = await s1.get(User, 7)

        assert isinstance(async_object_session(u1), MyCustomAsync)

        # release the connection checked out by get()
        await s1.close()

    @testing.requires.predictable_gc
    @async_test
    async def test_async_object_session_del(self, async_engine):
        User = self.classes.User

        s1 = AsyncSession(async_engine)

        u1 = await s1.get(User, 7)

        is_(async_object_session(u1), s1)

        await s1.rollback()
        del s1
        is_(async_object_session(u1), None)

    @async_test
    async def test_inspect_session(self, async_engine):
        User = self.classes.User

        s1 = AsyncSession(async_engine)

        s2 = AsyncSession(async_engine)

        u1 = await s1.get(User, 7)

        u2 = User(name="n1")

        s2.add(u2)

        u3 = User(name="n2")

        is_(inspect(u1).async_session, s1)
        is_(inspect(u2).async_session, s2)

        is_(inspect(u3).async_session, None)

        # release any connection resources held by the two sessions
        await s1.close()
        await s2.close()

    def test_inspect_session_no_asyncio_used(self):
        User = self.classes.User

        # the assertion stays inside the block so u1 is still attached to
        # a plain (non-async) Session when inspected
        with Session(testing.db) as s1:
            u1 = s1.get(User, 7)

            is_(inspect(u1).async_session, None)

    def test_inspect_session_no_asyncio_imported(self):
        with mock.patch("sqlalchemy.orm.state._async_provider", None):

            User = self.classes.User

            with Session(testing.db) as s1:
                u1 = s1.get(User, 7)

                is_(inspect(u1).async_session, None)

    @testing.requires.predictable_gc
    def test_gc(self, async_engine):
        ReversibleProxy._proxy_objects.clear()

        eq_(len(ReversibleProxy._proxy_objects), 0)

        async_session = AsyncSession(async_engine)

        eq_(len(ReversibleProxy._proxy_objects), 1)

        del async_session

        eq_(len(ReversibleProxy._proxy_objects), 0)
820
821
class _MySession(Session):
    """Trivial Session subclass used to verify ``sync_session_class``."""
824
825
class _MyAS(AsyncSession):
    """AsyncSession subclass naming ``_MySession`` as its
    ``sync_session_class`` at the class level."""

    sync_session_class = _MySession
828
829
class OverrideSyncSession(AsyncFixture):
    """Resolution of ``sync_session_class``.

    Covers the default, a constructor argument, a sessionmaker argument, a
    subclass attribute, and a constructor argument overriding the subclass
    attribute.
    """

    def test_default(self, async_engine):
        session = AsyncSession(async_engine)
        inner = session.sync_session

        is_true(isinstance(inner, Session))
        is_(inner.__class__, Session)
        is_(session.sync_session_class, Session)

    def test_init_class(self, async_engine):
        session = AsyncSession(async_engine, sync_session_class=_MySession)

        is_true(isinstance(session.sync_session, _MySession))
        is_(session.sync_session_class, _MySession)

    def test_init_sessionmaker(self, async_engine):
        factory = sessionmaker(
            async_engine, class_=AsyncSession, sync_session_class=_MySession
        )
        session = factory()

        is_true(isinstance(session.sync_session, _MySession))
        is_(session.sync_session_class, _MySession)

    def test_subclass(self, async_engine):
        session = _MyAS(async_engine)

        is_true(isinstance(session.sync_session, _MySession))
        is_(session.sync_session_class, _MySession)

    def test_subclass_override(self, async_engine):
        # an explicit argument wins over the subclass attribute
        session = _MyAS(async_engine, sync_session_class=Session)

        is_true(not isinstance(session.sync_session, _MySession))
        is_(session.sync_session_class, Session)
864