1import threading
2from asyncio import iscoroutine
3from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
4from contextlib import AbstractContextManager, contextmanager
5from types import TracebackType
6from typing import (
7    Any, AsyncContextManager, Callable, ContextManager, Coroutine, Dict, Generator, Iterable,
8    Optional, Tuple, Type, TypeVar, Union, cast, overload)
9from warnings import warn
10
11from ._core import _eventloop
12from ._core._eventloop import get_asynclib, get_cancelled_exc_class, threadlocals
13from ._core._synchronization import Event
14from ._core._tasks import CancelScope, create_task_group
15from .abc._tasks import TaskStatus
16
# Return type of the callables that the helpers below run on behalf of another thread
T_Retval = TypeVar('T_Retval')
# Value type produced by the (async) context managers wrapped through a portal
T_co = TypeVar('T_co')
19
20
def run(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object) -> T_Retval:
    """
    Call a coroutine function from a worker thread.

    :param func: a coroutine function
    :param args: positional arguments for the callable
    :return: the return value of the coroutine function
    :raises RuntimeError: if the current thread is not an AnyIO worker thread

    """
    try:
        asynclib = threadlocals.current_async_module
    except AttributeError:
        # No async backend is associated with this thread. The AttributeError is an internal
        # detail, so suppress it as the context of the user-facing error.
        raise RuntimeError('This function can only be run from an AnyIO worker thread') from None

    return asynclib.run_async_from_thread(func, *args)
36
37
def run_async_from_thread(func: Callable[..., Coroutine[Any, Any, T_Retval]],
                          *args: object) -> T_Retval:
    """
    Deprecated alias of :func:`anyio.from_thread.run`.

    :param func: a coroutine function
    :param args: positional arguments for the callable
    :return: the return value of the coroutine function

    """
    # stacklevel=2 attributes the warning to the caller rather than to this module, so the
    # default warning filters (which show DeprecationWarning for code in __main__) display it
    warn('run_async_from_thread() has been deprecated, use anyio.from_thread.run() instead',
         DeprecationWarning, stacklevel=2)
    return run(func, *args)
43
44
def run_sync(func: Callable[..., T_Retval], *args: object) -> T_Retval:
    """
    Call a function in the event loop thread from a worker thread.

    :param func: a callable
    :param args: positional arguments for the callable
    :return: the return value of the callable
    :raises RuntimeError: if the current thread is not an AnyIO worker thread

    """
    try:
        asynclib = threadlocals.current_async_module
    except AttributeError:
        # No async backend is associated with this thread. The AttributeError is an internal
        # detail, so suppress it as the context of the user-facing error.
        raise RuntimeError('This function can only be run from an AnyIO worker thread') from None

    return asynclib.run_sync_from_thread(func, *args)
60
61
def run_sync_from_thread(func: Callable[..., T_Retval], *args: object) -> T_Retval:
    """
    Deprecated alias of :func:`anyio.from_thread.run_sync`.

    :param func: a callable
    :param args: positional arguments for the callable
    :return: the return value of the callable

    """
    # stacklevel=2 attributes the warning to the caller rather than to this module, so the
    # default warning filters (which show DeprecationWarning for code in __main__) display it
    warn('run_sync_from_thread() has been deprecated, use anyio.from_thread.run_sync() instead',
         DeprecationWarning, stacklevel=2)
    return run_sync(func, *args)
66
67
class _BlockingAsyncContextManager(AbstractContextManager):
    """
    Synchronous context manager that drives an async context manager through a portal.

    A single portal task (:meth:`run_async_cm`) calls ``__aenter__()``, then waits until the
    synchronous :meth:`__exit__` sets ``_exit_event``, and finally calls ``__aexit__()``.

    """

    # Resolved with the value returned by __aenter__(), or with the exception it raised
    _enter_future: Future
    # Future of the run_async_cm() task; resolves with the return value of __aexit__()
    _exit_future: Future
    # Set (via the portal) from __exit__() to let run_async_cm() proceed to __aexit__()
    _exit_event: Event
    # Exception info handed to __aexit__(); keeps this class-level default until __exit__() runs
    _exit_exc_info: Tuple[Optional[Type[BaseException]], Optional[BaseException],
                          Optional[TracebackType]] = (None, None, None)

    def __init__(self, async_cm: AsyncContextManager[T_co], portal: 'BlockingPortal') -> None:
        """
        :param async_cm: the asynchronous context manager to wrap
        :param portal: the portal used to run the async context manager's methods

        """
        self._async_cm = async_cm
        self._portal = portal

    async def run_async_cm(self) -> Optional[bool]:
        """
        Enter the async context manager, wait for the sync side to exit, then exit it.

        Runs as a task inside the portal's event loop.

        :return: the return value of the wrapped context manager's ``__aexit__()``

        """
        try:
            # The event is created here, inside the event loop, before __aenter__() runs
            self._exit_event = Event()
            value = await self._async_cm.__aenter__()
        except BaseException as exc:
            # Propagate the failure to the thread blocked in __enter__()
            self._enter_future.set_exception(exc)
            raise
        else:
            self._enter_future.set_result(value)

        try:
            # Wait for the sync context manager to exit.
            # This next statement can raise `get_cancelled_exc_class()` if
            # something went wrong in a task group in this async context
            # manager.
            await self._exit_event.wait()
        finally:
            # In case of cancellation, we may end up here before
            # `_BlockingAsyncContextManager.__exit__` has been called, in which case
            # `_exit_exc_info` still holds its default of (None, None, None).
            # NOTE: returning from a `finally` clause discards any exception raised in the
            # `try` body above (e.g. the cancellation exception), so this task ends with
            # __aexit__()'s return value instead.
            result = await self._async_cm.__aexit__(*self._exit_exc_info)
            return result

    def __enter__(self) -> T_co:
        """
        Start :meth:`run_async_cm` in the portal and block until ``__aenter__()`` finishes.

        :return: the value returned by the wrapped ``__aenter__()``
        :raises BaseException: whatever ``__aenter__()`` raised, re-raised by ``Future.result()``

        """
        self._enter_future = Future()
        self._exit_future = self._portal.start_task_soon(self.run_async_cm)
        cm = self._enter_future.result()
        return cast(T_co, cm)

    def __exit__(self, __exc_type: Optional[Type[BaseException]],
                 __exc_value: Optional[BaseException],
                 __traceback: Optional[TracebackType]) -> Optional[bool]:
        """
        Hand the exception info to the portal task and block until ``__aexit__()`` returns.

        :return: the return value of the wrapped ``__aexit__()``; a truthy value suppresses the
            exception raised in the ``with`` block

        """
        self._exit_exc_info = __exc_type, __exc_value, __traceback
        # Setting the event must happen in the event loop thread, hence the portal call
        self._portal.call(self._exit_event.set)
        return self._exit_future.result()
114
115
class _BlockingPortalTaskStatus(TaskStatus):
    """
    :class:`TaskStatus` implementation backed by a :class:`concurrent.futures.Future`.

    :meth:`BlockingPortal.start_task` blocks on the wrapped future until :meth:`started` resolves
    it.

    """

    def __init__(self, future: Future):
        self._future = future

    def started(self, value: object = None) -> None:
        """
        Signal that the task has started, handing ``value`` to the waiting thread.

        :param value: the value to resolve the wrapped future with

        """
        pending = self._future
        pending.set_result(value)
122
123
class BlockingPortal:
    """
    An object that lets external threads run code in an asynchronous event loop.

    Instantiating this class returns the implementation provided by the current async backend.
    Use it as an async context manager in the event loop thread; other threads can then use
    :meth:`call`, :meth:`start_task_soon`, :meth:`start_task` and
    :meth:`wrap_async_context_manager` until :meth:`stop` is called.

    """

    def __new__(cls) -> 'BlockingPortal':
        # Delegate to the backend-specific implementation of the current async library
        return get_asynclib().BlockingPortal()

    def __init__(self) -> None:
        # Thread that created the portal (the event loop thread). _check_running() rejects calls
        # made from this thread, and None marks a stopped portal.
        self._event_loop_thread_id: Optional[int] = threading.get_ident()
        self._stop_event = Event()
        self._task_group = create_task_group()
        self._cancelled_exc_class = get_cancelled_exc_class()

    async def __aenter__(self) -> 'BlockingPortal':
        await self._task_group.__aenter__()
        return self

    async def __aexit__(self, exc_type: Optional[Type[BaseException]],
                        exc_val: Optional[BaseException],
                        exc_tb: Optional[TracebackType]) -> Optional[bool]:
        # Stop accepting new calls (without cancelling), then exit the task group
        await self.stop()
        return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)

    def _check_running(self) -> None:
        """
        Ensure the portal may be used from the current thread.

        :raises RuntimeError: if the portal has been stopped, or if called from the event loop
            thread (where blocking on the portal would deadlock the loop)

        """
        if self._event_loop_thread_id is None:
            raise RuntimeError('This portal is not running')
        if self._event_loop_thread_id == threading.get_ident():
            raise RuntimeError('This method cannot be called from the event loop thread')

    async def sleep_until_stopped(self) -> None:
        """Sleep until :meth:`stop` is called."""
        await self._stop_event.wait()

    async def stop(self, cancel_remaining: bool = False) -> None:
        """
        Signal the portal to shut down.

        This marks the portal as no longer accepting new calls and exits from
        :meth:`sleep_until_stopped`.

        :param cancel_remaining: ``True`` to cancel all the remaining tasks, ``False`` to let them
            finish before returning

        """
        self._event_loop_thread_id = None
        self._stop_event.set()
        if cancel_remaining:
            self._task_group.cancel_scope.cancel()

    async def _call_func(self, func: Callable, args: tuple, kwargs: Dict[str, Any],
                         future: Future) -> None:
        """
        Run ``func`` in the event loop and resolve ``future`` with the outcome.

        If ``func`` returns a coroutine, it is awaited inside a cancel scope that is cancelled
        when ``future`` is cancelled from another thread.

        :param func: the callable to run
        :param args: positional arguments for ``func``
        :param kwargs: keyword arguments for ``func``
        :param future: resolved with the return value or the raised exception, or left
            cancelled if the call was cancelled

        """
        def callback(f: Future) -> None:
            # Runs in the thread that cancelled the future: forward the cancellation to the
            # task's cancel scope (skipped in the event loop thread and on a stopped portal)
            if f.cancelled() and self._event_loop_thread_id not in (None, threading.get_ident()):
                self.call(scope.cancel)

        # NOTE: Future.cancel() does not wake threads blocked in concurrent.futures.wait() or
        # as_completed(). For futures not driven by an executor, that takes an explicit (single)
        # call to set_running_or_notify_cancel() once the future is cancelled; every cancelled
        # outcome below does so exactly once.
        try:
            retval = func(*args, **kwargs)
            if iscoroutine(retval):
                with CancelScope() as scope:
                    if future.cancelled():
                        scope.cancel()
                    else:
                        future.add_done_callback(callback)

                    retval = await retval
        except self._cancelled_exc_class:
            # Cancelled from within the event loop (e.g. stop(cancel_remaining=True))
            future.cancel()
            future.set_running_or_notify_cancel()
        except BaseException as exc:
            if future.cancelled():
                future.set_running_or_notify_cancel()
            else:
                future.set_exception(exc)

            # Let base exceptions fall through
            if not isinstance(exc, Exception):
                raise
        else:
            # A cancellation requested through the future is absorbed by the cancel scope and
            # lands here, with the future already in the cancelled state
            if future.cancelled():
                future.set_running_or_notify_cancel()
            else:
                future.set_result(retval)
        finally:
            # Break the reference cycle between the callback closure and the cancel scope
            scope = None  # type: ignore[assignment]

    def _spawn_task_from_thread(self, func: Callable, args: tuple, kwargs: Dict[str, Any],
                                name: object, future: Future) -> None:
        """
        Spawn a new task using the given callable.

        Implementors must ensure that the future is resolved when the task finishes.

        :param func: a callable
        :param args: positional arguments to be passed to the callable
        :param kwargs: keyword arguments to be passed to the callable
        :param name: name of the task (will be coerced to a string if not ``None``)
        :param future: a future that will resolve to the return value of the callable, or the
            exception raised during its execution

        """
        raise NotImplementedError

    @overload
    def call(self, func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object) -> T_Retval:
        ...

    @overload
    def call(self, func: Callable[..., T_Retval], *args: object) -> T_Retval:
        ...

    def call(self, func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_Retval]],
             *args: object) -> T_Retval:
        """
        Call the given function in the event loop thread.

        If the callable returns a coroutine object, it is awaited on.

        :param func: any callable
        :raises RuntimeError: if the portal is not running or if this method is called from within
            the event loop thread

        """
        return cast(T_Retval, self.start_task_soon(func, *args).result())

    @overload
    def spawn_task(self, func: Callable[..., Coroutine[Any, Any, T_Retval]],
                   *args: object, name: object = None) -> "Future[T_Retval]":
        ...

    @overload
    def spawn_task(self, func: Callable[..., T_Retval],
                   *args: object, name: object = None) -> "Future[T_Retval]": ...

    def spawn_task(self, func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_Retval]],
                   *args: object, name: object = None) -> "Future[T_Retval]":
        """
        Start a task in the portal's task group.

        :param func: the target coroutine function
        :param args: positional arguments passed to ``func``
        :param name: name of the task (will be coerced to a string if not ``None``)
        :return: a future that resolves with the return value of the callable if the task completes
            successfully, or with the exception raised in the task
        :raises RuntimeError: if the portal is not running or if this method is called from within
            the event loop thread

        .. versionadded:: 2.1
        .. deprecated:: 3.0
           Use :meth:`start_task_soon` instead. If your code needs AnyIO 2 compatibility, you
           can keep using this until AnyIO 4.

        """
        # stacklevel=2 attributes the warning to the caller instead of this module
        warn('spawn_task() is deprecated -- use start_task_soon() instead', DeprecationWarning,
             stacklevel=2)
        return self.start_task_soon(func, *args, name=name)  # type: ignore[arg-type]

    @overload
    def start_task_soon(self, func: Callable[..., Coroutine[Any, Any, T_Retval]],
                        *args: object, name: object = None) -> "Future[T_Retval]":
        ...

    @overload
    def start_task_soon(self, func: Callable[..., T_Retval],
                        *args: object, name: object = None) -> "Future[T_Retval]": ...

    def start_task_soon(self, func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_Retval]],
                        *args: object, name: object = None) -> "Future[T_Retval]":
        """
        Start a task in the portal's task group.

        The task will be run inside a cancel scope which can be cancelled by cancelling the
        returned future.

        :param func: the target coroutine function
        :param args: positional arguments passed to ``func``
        :param name: name of the task (will be coerced to a string if not ``None``)
        :return: a future that resolves with the return value of the callable if the task completes
            successfully, or with the exception raised in the task
        :raises RuntimeError: if the portal is not running or if this method is called from within
            the event loop thread

        .. versionadded:: 3.0

        """
        self._check_running()
        f: Future = Future()
        self._spawn_task_from_thread(func, args, {}, name, f)
        return f

    def start_task(self, func: Callable[..., Coroutine], *args: object,
                   name: object = None) -> Tuple[Future, Any]:
        """
        Start a task in the portal's task group and wait until it signals for readiness.

        This method works the same way as :meth:`TaskGroup.start`.

        :param func: the target coroutine function
        :param args: positional arguments passed to ``func``
        :param name: name of the task (will be coerced to a string if not ``None``)
        :return: a tuple of (future, task_status_value) where the ``task_status_value`` is the
            value passed to ``task_status.started()`` from within the target function

        .. versionadded:: 3.0

        """
        def task_done(future: Future) -> None:
            # If the task ends before calling task_status.started(), unblock the waiting thread
            # with the task's outcome (or an error if it simply returned)
            if not task_status_future.done():
                if future.cancelled():
                    task_status_future.cancel()
                elif future.exception():
                    task_status_future.set_exception(future.exception())
                else:
                    exc = RuntimeError('Task exited without calling task_status.started()')
                    task_status_future.set_exception(exc)

        self._check_running()
        task_status_future: Future = Future()
        task_status = _BlockingPortalTaskStatus(task_status_future)
        f: Future = Future()
        f.add_done_callback(task_done)
        self._spawn_task_from_thread(func, args, {'task_status': task_status}, name, f)
        return f, task_status_future.result()

    def wrap_async_context_manager(self, cm: AsyncContextManager[T_co]) -> ContextManager[T_co]:
        """
        Wrap an async context manager as a synchronous context manager via this portal.

        Spawns a task that will call both ``__aenter__()`` and ``__aexit__()``, stopping in the
        middle until the synchronous context manager exits.

        :param cm: an asynchronous context manager
        :return: a synchronous context manager

        .. versionadded:: 2.1

        """
        return _BlockingAsyncContextManager(cm, self)
354
355
def create_blocking_portal() -> BlockingPortal:
    """
    Create a portal for running functions in the event loop thread from external threads.

    Use this function in asynchronous code when you need to allow external threads access to the
    event loop where your asynchronous code is currently running.

    .. deprecated:: 3.0
        Use :class:`.BlockingPortal` directly.

    """
    # stacklevel=2 attributes the warning to the caller rather than to this module, so the
    # default warning filters (which show DeprecationWarning for code in __main__) display it
    warn('create_blocking_portal() has been deprecated -- use anyio.from_thread.BlockingPortal() '
         'directly', DeprecationWarning, stacklevel=2)
    return BlockingPortal()
370
371
@contextmanager
def start_blocking_portal(
        backend: str = 'asyncio',
        backend_options: Optional[Dict[str, Any]] = None) -> Generator[BlockingPortal, Any, None]:
    """
    Start a new event loop in a new thread and run a blocking portal in its main task.

    The parameters are the same as for :func:`~anyio.run`.

    :param backend: name of the backend
    :param backend_options: backend options
    :return: a context manager that yields a blocking portal

    .. versionchanged:: 3.0
        Usage as a context manager is now required.

    """
    async def run_portal() -> None:
        async with BlockingPortal() as portal_:
            # Only publish the portal if the starting thread has not given up (cancelled) waiting
            if future.set_running_or_notify_cancel():
                future.set_result(portal_)
                await portal_.sleep_until_stopped()

    future: Future[BlockingPortal] = Future()
    with ThreadPoolExecutor(1) as executor:
        run_future = executor.submit(_eventloop.run, run_portal, backend=backend,
                                     backend_options=backend_options)
        try:
            # Returns once the portal is ready, or once the event loop has exited (for example
            # because it failed to start)
            wait(cast(Iterable[Future], [run_future, future]), return_when=FIRST_COMPLETED)
        except BaseException:
            future.cancel()
            run_future.cancel()
            raise

        if future.done():
            portal = future.result()
            cancel_remaining_tasks = False
            try:
                yield portal
            except BaseException:
                # The with-block failed: cancel outstanding portal tasks instead of waiting
                cancel_remaining_tasks = True
                raise
            finally:
                try:
                    portal.call(portal.stop, cancel_remaining_tasks)
                except RuntimeError:
                    # The portal was already stopped (e.g. by the caller inside the with-block).
                    # That must neither replace an in-flight exception nor break a normal exit.
                    pass

        # Propagate any exception raised in the event loop thread
        run_future.result()
417