1import functools
2import typing
3from typing import Any, AsyncGenerator, Iterator
4
5import anyio
6
7try:
8    import contextvars  # Python 3.7+ only or via contextvars backport.
9except ImportError:  # pragma: no cover
10    contextvars = None  # type: ignore
11
12
# Generic type variable tying run_in_threadpool's return type to the result
# type of the callable it is given.
T = typing.TypeVar("T")
14
15
async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None:
    """Run several coroutine functions concurrently until the first one returns.

    Each positional argument is a ``(func, kwargs)`` pair. ``func(**kwargs)`` is
    awaited in its own task inside a single anyio task group. As soon as any
    one of those tasks finishes, the task group's cancel scope is cancelled so
    the remaining tasks are torn down as well.
    """
    async with anyio.create_task_group() as group:

        async def _await_then_cancel(
            target: typing.Callable[[], typing.Coroutine]
        ) -> None:
            # Whichever task completes first cancels all of its siblings.
            await target()
            group.cancel_scope.cancel()

        for coro_func, call_kwargs in args:
            # Bind the keyword arguments now; start_soon only forwards positionals.
            bound = functools.partial(coro_func, **call_kwargs)
            group.start_soon(_await_then_cancel, bound)
25
26
async def run_in_threadpool(
    func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any
) -> T:
    """Call ``func(*args, **kwargs)`` in a worker thread and await its result.

    When :mod:`contextvars` is importable, the call is executed via
    ``Context.run`` on a copy of the caller's context, so context variables
    visible to the caller are also visible inside ``func``. Otherwise the
    call is dispatched directly, with any keyword arguments pre-bound because
    ``anyio.to_thread.run_sync`` only forwards positional arguments.
    """
    if contextvars is None:  # pragma: no cover
        # run_sync doesn't accept 'kwargs', so bind them in here
        target = functools.partial(func, **kwargs) if kwargs else func
        return await anyio.to_thread.run_sync(target, *args)

    # pragma: no cover
    # Ensure we run in the same context
    call = functools.partial(func, *args, **kwargs)
    snapshot = contextvars.copy_context()
    return await anyio.to_thread.run_sync(snapshot.run, call)
40
41
42class _StopIteration(Exception):
43    pass
44
45
46def _next(iterator: Iterator) -> Any:
47    # We can't raise `StopIteration` from within the threadpool iterator
48    # and catch it outside that context, so we coerce them into a different
49    # exception type.
50    try:
51        return next(iterator)
52    except StopIteration:
53        raise _StopIteration
54
55
async def iterate_in_threadpool(iterator: typing.Iterable[Any]) -> AsyncGenerator:
    """Asynchronously iterate a blocking iterator, advancing it in a worker thread.

    Each step is dispatched through ``anyio.to_thread.run_sync(_next, ...)`` so
    the synchronous ``next()`` call runs off the event loop.

    :param iterator: An iterator, or any iterable (it is passed through
        ``iter()`` first, which is a no-op for iterators).
    :yields: The items produced by the underlying iterator, in order.
    """
    # iter() returns iterators unchanged, and lets plain iterables (lists,
    # tuples, ...) be used too instead of failing inside next().
    as_iterator = iter(iterator)
    while True:
        try:
            item = await anyio.to_thread.run_sync(_next, as_iterator)
        except _StopIteration:
            break
        # Yield outside the try: an exception thrown into this generator at
        # the yield point must never be mistaken for exhaustion.
        yield item
62