1# ext/asyncio/session.py
2# Copyright (C) 2020-2021 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7from . import engine
8from . import result as _result
9from .base import ReversibleProxy
10from .base import StartableContext
11from ... import util
12from ...orm import object_session
13from ...orm import Session
14from ...orm import state as _instance_state
15from ...util.concurrency import greenlet_spawn
16
# Option merged into every AsyncSession.execute() call so that the returned
# Result is buffered (see the AsyncSession.execute docstring).
_EXECUTE_OPTIONS = util.immutabledict(dict(prebuffer_rows=True))
# Option merged into every AsyncSession.stream() call so that a streaming
# AsyncResult is produced (see the AsyncSession.stream docstring).
_STREAM_OPTIONS = util.immutabledict(dict(stream_results=True))
19
20
@util.create_proxy_methods(
    Session,
    ":class:`_orm.Session`",
    ":class:`_asyncio.AsyncSession`",
    classmethods=["object_session", "identity_key"],
    methods=[
        "__contains__",
        "__iter__",
        "add",
        "add_all",
        "expire",
        "expire_all",
        "expunge",
        "expunge_all",
        "get_bind",
        "is_modified",
        "in_transaction",
        "in_nested_transaction",
    ],
    attributes=[
        "dirty",
        "deleted",
        "new",
        "identity_map",
        "is_active",
        "autoflush",
        "no_autoflush",
        "info",
    ],
)
class AsyncSession(ReversibleProxy):
    """Asyncio version of :class:`_orm.Session`.

    The :class:`_asyncio.AsyncSession` is a proxy for a traditional
    :class:`_orm.Session` instance.

    .. versionadded:: 1.4

    To use an :class:`_asyncio.AsyncSession` with custom :class:`_orm.Session`
    implementations, see the
    :paramref:`_asyncio.AsyncSession.sync_session_class` parameter.


    """

    _is_asyncio = True

    dispatch = None

    def __init__(self, bind=None, binds=None, sync_session_class=None, **kw):
        r"""Construct a new :class:`_asyncio.AsyncSession`.

        All parameters other than ``sync_session_class`` are passed to the
        ``sync_session_class`` callable directly to instantiate a new
        :class:`_orm.Session`. Refer to :meth:`_orm.Session.__init__` for
        parameter documentation.

        :param sync_session_class:
          A :class:`_orm.Session` subclass or other callable which will be used
          to construct the :class:`_orm.Session` which will be proxied. This
          parameter may be used to provide custom :class:`_orm.Session`
          subclasses. Defaults to the
          :attr:`_asyncio.AsyncSession.sync_session_class` class-level
          attribute.

          .. versionadded:: 1.4.24

        """
        # The proxied Session is always constructed with future=True,
        # regardless of what the caller passed.
        kw["future"] = True
        if bind:
            # Remember the async object the caller gave us; the sync
            # Session receives its sync engine / connection instead.
            self.bind = bind
            bind = engine._get_sync_engine_or_connection(bind)

        if binds:
            # Same as ``bind``: keep the async mapping here and pass the
            # sync equivalents through to the Session.
            self.binds = binds
            binds = {
                key: engine._get_sync_engine_or_connection(b)
                for key, b in binds.items()
            }

        if sync_session_class:
            # Instance-level override of the class-level default below.
            self.sync_session_class = sync_session_class

        self.sync_session = self._proxied = self._assign_proxied(
            self.sync_session_class(bind=bind, binds=binds, **kw)
        )

    sync_session_class = Session
    """The class or callable that provides the
    underlying :class:`_orm.Session` instance for a particular
    :class:`_asyncio.AsyncSession`.

    At the class level, this attribute is the default value for the
    :paramref:`_asyncio.AsyncSession.sync_session_class` parameter. Custom
    subclasses of :class:`_asyncio.AsyncSession` can override this.

    At the instance level, this attribute indicates the current class or
    callable that was used to provide the :class:`_orm.Session` instance for
    this :class:`_asyncio.AsyncSession` instance.

    .. versionadded:: 1.4.24

    """

    sync_session: Session
    """Reference to the underlying :class:`_orm.Session` this
    :class:`_asyncio.AsyncSession` proxies requests towards.

    This instance can be used as an event target.

    .. seealso::

        :ref:`asyncio_events`

    """

    async def refresh(
        self, instance, attribute_names=None, with_for_update=None
    ):
        """Expire and refresh the attributes on the given instance.

        A query will be issued to the database and all attributes will be
        refreshed with their current database value.

        This is the async version of the :meth:`_orm.Session.refresh` method.
        See that method for a complete description of all options.

        .. seealso::

            :meth:`_orm.Session.refresh` - main documentation for refresh

        """

        return await greenlet_spawn(
            self.sync_session.refresh,
            instance,
            attribute_names=attribute_names,
            with_for_update=with_for_update,
        )

    async def run_sync(self, fn, *arg, **kw):
        """Invoke the given sync callable passing sync self as the first
        argument.

        This method maintains the asyncio event loop all the way through
        to the database connection by running the given callable in a
        specially instrumented greenlet.

        E.g.::

            async with AsyncSession(async_engine) as session:
                await session.run_sync(some_business_method)

        .. note::

            The provided callable is invoked inline within the asyncio event
            loop, and will block on traditional IO calls.  IO within this
            callable should only call into SQLAlchemy's asyncio database
            APIs which will be properly adapted to the greenlet context.

        .. seealso::

            :ref:`session_run_sync`
        """

        return await greenlet_spawn(fn, self.sync_session, *arg, **kw)

    async def execute(
        self,
        statement,
        params=None,
        execution_options=util.EMPTY_DICT,
        bind_arguments=None,
        **kw
    ):
        """Execute a statement and return a buffered
        :class:`_engine.Result` object.

        .. seealso::

            :meth:`_orm.Session.execute` - main documentation for execute

        """

        # Always request a pre-buffered result; caller-supplied options are
        # merged with it rather than replaced.
        if execution_options:
            execution_options = util.immutabledict(execution_options).union(
                _EXECUTE_OPTIONS
            )
        else:
            execution_options = _EXECUTE_OPTIONS

        return await greenlet_spawn(
            self.sync_session.execute,
            statement,
            params=params,
            execution_options=execution_options,
            bind_arguments=bind_arguments,
            **kw
        )

    async def scalar(
        self,
        statement,
        params=None,
        execution_options=util.EMPTY_DICT,
        bind_arguments=None,
        **kw
    ):
        """Execute a statement and return a scalar result.

        .. seealso::

            :meth:`_orm.Session.scalar` - main documentation for scalar

        """

        result = await self.execute(
            statement,
            params=params,
            execution_options=execution_options,
            bind_arguments=bind_arguments,
            **kw
        )
        return result.scalar()

    async def scalars(
        self,
        statement,
        params=None,
        execution_options=util.EMPTY_DICT,
        bind_arguments=None,
        **kw
    ):
        """Execute a statement and return scalar results.

        :return: a :class:`_result.ScalarResult` object

        .. versionadded:: 1.4.24

        .. seealso::

            :meth:`_orm.Session.scalars` - main documentation for scalars

            :meth:`_asyncio.AsyncSession.stream_scalars` - streaming version

        """

        result = await self.execute(
            statement,
            params=params,
            execution_options=execution_options,
            bind_arguments=bind_arguments,
            **kw
        )
        return result.scalars()

    async def get(
        self,
        entity,
        ident,
        options=None,
        populate_existing=False,
        with_for_update=None,
        identity_token=None,
    ):
        """Return an instance based on the given primary key identifier,
        or ``None`` if not found.

        .. seealso::

            :meth:`_orm.Session.get` - main documentation for get


        """
        return await greenlet_spawn(
            self.sync_session.get,
            entity,
            ident,
            options=options,
            populate_existing=populate_existing,
            with_for_update=with_for_update,
            identity_token=identity_token,
        )

    async def stream(
        self,
        statement,
        params=None,
        execution_options=util.EMPTY_DICT,
        bind_arguments=None,
        **kw
    ):
        """Execute a statement and return a streaming
        :class:`_asyncio.AsyncResult` object.

        .. seealso::

            :meth:`_asyncio.AsyncSession.execute` - buffered version

        """

        # Always request a streaming result; caller-supplied options are
        # merged with it rather than replaced.
        if execution_options:
            execution_options = util.immutabledict(execution_options).union(
                _STREAM_OPTIONS
            )
        else:
            execution_options = _STREAM_OPTIONS

        result = await greenlet_spawn(
            self.sync_session.execute,
            statement,
            params=params,
            execution_options=execution_options,
            bind_arguments=bind_arguments,
            **kw
        )
        return _result.AsyncResult(result)

    async def stream_scalars(
        self,
        statement,
        params=None,
        execution_options=util.EMPTY_DICT,
        bind_arguments=None,
        **kw
    ):
        """Execute a statement and return a stream of scalar results.

        :return: an :class:`_asyncio.AsyncScalarResult` object

        .. versionadded:: 1.4.24

        .. seealso::

            :meth:`_orm.Session.scalars` - main documentation for scalars

            :meth:`_asyncio.AsyncSession.scalars` - non streaming version

        """

        result = await self.stream(
            statement,
            params=params,
            execution_options=execution_options,
            bind_arguments=bind_arguments,
            **kw
        )
        return result.scalars()

    async def delete(self, instance):
        """Mark an instance as deleted.

        The database delete operation occurs upon ``flush()``.

        As this operation may need to cascade along unloaded relationships,
        it is awaitable to allow for those queries to take place.

        .. seealso::

            :meth:`_orm.Session.delete` - main documentation for delete

        """
        return await greenlet_spawn(self.sync_session.delete, instance)

    async def merge(self, instance, load=True, options=None):
        """Copy the state of a given instance into a corresponding instance
        within this :class:`_asyncio.AsyncSession`.

        .. seealso::

            :meth:`_orm.Session.merge` - main documentation for merge

        """
        return await greenlet_spawn(
            self.sync_session.merge, instance, load=load, options=options
        )

    async def flush(self, objects=None):
        """Flush all the object changes to the database.

        .. seealso::

            :meth:`_orm.Session.flush` - main documentation for flush

        """
        await greenlet_spawn(self.sync_session.flush, objects=objects)

    def get_transaction(self):
        """Return the current root transaction in progress, if any.

        :return: an :class:`_asyncio.AsyncSessionTransaction` object, or
         ``None``.

        .. versionadded:: 1.4.18

        """
        trans = self.sync_session.get_transaction()
        if trans is not None:
            return AsyncSessionTransaction._retrieve_proxy_for_target(trans)
        else:
            return None

    def get_nested_transaction(self):
        """Return the current nested transaction in progress, if any.

        :return: an :class:`_asyncio.AsyncSessionTransaction` object, or
         ``None``.

        .. versionadded:: 1.4.18

        """

        trans = self.sync_session.get_nested_transaction()
        if trans is not None:
            return AsyncSessionTransaction._retrieve_proxy_for_target(trans)
        else:
            return None

    async def connection(self, **kw):
        r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to
        this :class:`.Session` object's transactional state.

        This method may also be used to establish execution options for the
        database connection used by the current transaction.

        .. versionadded:: 1.4.24  Added **kw arguments which are passed through
           to the underlying :meth:`_orm.Session.connection` method.

        .. seealso::

            :meth:`_orm.Session.connection` - main documentation for
            "connection"

        """

        sync_connection = await greenlet_spawn(
            self.sync_session.connection, **kw
        )
        return engine.AsyncConnection._retrieve_proxy_for_target(
            sync_connection
        )

    def begin(self, **kw):
        """Return an :class:`_asyncio.AsyncSessionTransaction` object.

        The underlying :class:`_orm.Session` will perform the
        "begin" action when the :class:`_asyncio.AsyncSessionTransaction`
        object is entered::

            async with async_session.begin():
                # .. ORM transaction is begun

        Note that database IO will not normally occur when the session-level
        transaction is begun, as database transactions begin on an
        on-demand basis.  However, the begin block is async to accommodate
        for a :meth:`_orm.SessionEvents.after_transaction_create`
        event hook that may perform IO.

        For a general description of ORM begin, see
        :meth:`_orm.Session.begin`.

        Keyword arguments are accepted for signature compatibility but are
        currently not used.

        """

        return AsyncSessionTransaction(self)

    def begin_nested(self, **kw):
        """Return an :class:`_asyncio.AsyncSessionTransaction` object
        which will begin a "nested" transaction, e.g. SAVEPOINT.

        Behavior is the same as that of :meth:`_asyncio.AsyncSession.begin`.

        For a general description of ORM begin nested, see
        :meth:`_orm.Session.begin_nested`.

        Keyword arguments are accepted for signature compatibility but are
        currently not used.

        """

        return AsyncSessionTransaction(self, nested=True)

    async def rollback(self):
        """Rollback the current transaction in progress."""
        return await greenlet_spawn(self.sync_session.rollback)

    async def commit(self):
        """Commit the current transaction in progress."""
        return await greenlet_spawn(self.sync_session.commit)

    async def close(self):
        """Close out the transactional resources and ORM objects used by this
        :class:`_asyncio.AsyncSession`.

        This expunges all ORM objects associated with this
        :class:`_asyncio.AsyncSession`, ends any transaction in progress and
        :term:`releases` any :class:`_asyncio.AsyncConnection` objects which
        this :class:`_asyncio.AsyncSession` itself has checked out from
        associated :class:`_asyncio.AsyncEngine` objects. The operation then
        leaves the :class:`_asyncio.AsyncSession` in a state which it may be
        used again.

        .. tip::

            The :meth:`_asyncio.AsyncSession.close` method **does not prevent
            the Session from being used again**. The
            :class:`_asyncio.AsyncSession` itself does not actually have a
            distinct "closed" state; it merely means the
            :class:`_asyncio.AsyncSession` will release all database
            connections and ORM objects.


        .. seealso::

            :ref:`session_closing` - detail on the semantics of
            :meth:`_asyncio.AsyncSession.close`

        """
        return await greenlet_spawn(self.sync_session.close)

    @classmethod
    async def close_all(cls):
        """Close all :class:`_asyncio.AsyncSession` sessions.

        This is performed via the ORM-level ``close_all_sessions()``
        function, run inside a greenlet so that any IO it performs is
        adapted to asyncio.

        """
        # ``sync_session`` is only assigned per instance in __init__; it is
        # an annotation (not a value) at class level, so it cannot be used
        # from this classmethod.  Import locally to keep the module's
        # top-level imports unchanged.
        from ...orm.session import close_all_sessions

        return await greenlet_spawn(close_all_sessions)

    async def __aenter__(self):
        return self

    async def __aexit__(self, type_, value, traceback):
        # Exiting ``async with session:`` always closes the session,
        # whether or not an exception occurred; exceptions propagate.
        await self.close()

    def _maker_context_manager(self):
        # A hand-written async context manager class is used because
        # @contextlib.asynccontextmanager is unavailable before Python 3.7.
        return _AsyncSessionContextManager(self)
545
546
547class _AsyncSessionContextManager:
548    def __init__(self, async_session):
549        self.async_session = async_session
550
551    async def __aenter__(self):
552        self.trans = self.async_session.begin()
553        await self.trans.__aenter__()
554        return self.async_session
555
556    async def __aexit__(self, type_, value, traceback):
557        await self.trans.__aexit__(type_, value, traceback)
558        await self.async_session.__aexit__(type_, value, traceback)
559
560
class AsyncSessionTransaction(ReversibleProxy, StartableContext):
    """A wrapper for the ORM :class:`_orm.SessionTransaction` object.

    This object is provided so that a transaction-holding object
    for the :meth:`_asyncio.AsyncSession.begin` may be returned.

    The object supports both explicit calls to
    :meth:`_asyncio.AsyncSessionTransaction.commit` and
    :meth:`_asyncio.AsyncSessionTransaction.rollback`, as well as use as an
    async context manager.


    .. versionadded:: 1.4

    """

    __slots__ = ("session", "sync_transaction", "nested")

    def __init__(self, session, nested=False):
        # Owning AsyncSession; start() begins the transaction on its
        # ``sync_session``.
        self.session = session
        # When true, start() uses begin_nested() instead of begin().
        self.nested = nested
        # The proxied sync transaction; stays None until start() runs.
        self.sync_transaction = None

    @property
    def is_active(self):
        """Return ``True`` if the underlying
        :class:`_orm.SessionTransaction` is active."""
        # _sync_transaction() defers to _raise_for_not_started() when
        # start() has not been called yet.
        sync_trans = self._sync_transaction()
        return sync_trans is not None and sync_trans.is_active

    def _sync_transaction(self):
        """Return the sync transaction created by :meth:`start`, invoking
        ``_raise_for_not_started()`` first if it has not been created."""
        if self.sync_transaction is None:
            self._raise_for_not_started()
        return self.sync_transaction

    async def rollback(self):
        """Roll back this :class:`_asyncio.AsyncSessionTransaction`."""
        await greenlet_spawn(self._sync_transaction().rollback)

    async def commit(self):
        """Commit this :class:`_asyncio.AsyncSessionTransaction`."""

        await greenlet_spawn(self._sync_transaction().commit)

    async def start(self, is_ctxmanager=False):
        """Begin the underlying ORM transaction and return ``self``.

        :param is_ctxmanager: when true, the new sync transaction is also
         entered via ``__enter__()``, pairing with the ``__exit__()`` call
         made by ``__aexit__``.
        """
        begin = (
            self.session.sync_session.begin_nested
            if self.nested
            else self.session.sync_session.begin
        )
        self.sync_transaction = self._assign_proxied(
            await greenlet_spawn(begin)
        )
        if is_ctxmanager:
            self.sync_transaction.__enter__()
        return self

    async def __aexit__(self, type_, value, traceback):
        # Run the sync transaction's __exit__ inside a greenlet so that any
        # database IO it performs is adapted to the asyncio event loop.
        await greenlet_spawn(
            self._sync_transaction().__exit__, type_, value, traceback
        )
621
622
def async_object_session(instance):
    """Look up the :class:`_asyncio.AsyncSession` that owns ``instance``.

    The sync-API function :class:`_orm.object_session` is used first to
    find the :class:`_orm.Session` the instance belongs to; that
    :class:`_orm.Session` is then mapped back to the
    :class:`_asyncio.AsyncSession` proxying it.

    ``None`` is returned when the instance has no session, or when the
    :class:`_asyncio.AsyncSession` has been garbage collected.

    The same lookup is reachable through the
    :attr:`_orm.InstanceState.async_session` accessor.

    :param instance: an ORM mapped instance
    :return: an :class:`_asyncio.AsyncSession` object, or ``None``.

    .. versionadded:: 1.4.18

    """
    sync_session = object_session(instance)
    if sync_session is None:
        return None
    return async_session(sync_session)
650
651
def async_session(session):
    """Find the :class:`_asyncio.AsyncSession` proxying a
    :class:`_orm.Session`.

    :param session: a :class:`_orm.Session` instance.
    :return: the :class:`_asyncio.AsyncSession` that wraps ``session``, or
     ``None`` if there is none.

    .. versionadded:: 1.4.18

    """
    # regenerate=False: look up an existing proxy only, as this function
    # is documented to return None when no AsyncSession wraps ``session``.
    proxy = AsyncSession._retrieve_proxy_for_target(
        session, regenerate=False
    )
    return proxy
663
664
# Import-time registration: hand async_session() to the ORM instance-state
# module so it can map a Session back to its AsyncSession.  This presumably
# backs the InstanceState.async_session accessor mentioned in the
# async_object_session() docstring -- confirm in orm/state.py.
_instance_state._async_provider = async_session
666