import functools
import typing
from typing import Any, AsyncGenerator, Iterable, Iterator

import anyio
import anyio.to_thread  # explicit: don't rely on `anyio` importing this submodule for us

try:
    import contextvars  # Python 3.7+ only or via contextvars backport.
except ImportError:  # pragma: no cover
    contextvars = None  # type: ignore


T = typing.TypeVar("T")


async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None:
    """Run several coroutine functions concurrently and stop once the first one returns.

    Each positional argument is a ``(func, kwargs)`` pair. ``func(**kwargs)`` is
    awaited in its own task inside a single anyio task group. As soon as any of
    them returns, the task group's cancel scope is cancelled, which cancels the
    tasks that are still running.
    """
    async with anyio.create_task_group() as task_group:

        async def run(func: typing.Callable[[], typing.Coroutine]) -> None:
            await func()
            # First finisher wins: cancel every sibling still in the group.
            task_group.cancel_scope.cancel()

        for func, kwargs in args:
            task_group.start_soon(run, functools.partial(func, **kwargs))


async def run_in_threadpool(
    func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any
) -> T:
    """Call the synchronous ``func(*args, **kwargs)`` in a worker thread and await it.

    When :mod:`contextvars` is available, the call runs inside a copy of the
    caller's current context (``contextvars.copy_context()``). Context variables
    set by the caller are therefore visible to ``func``. Changes ``func`` makes
    to them apply only to that copy.

    Returns whatever ``func`` returns; exceptions raised by ``func`` propagate
    through ``anyio.to_thread.run_sync``.
    """
    if contextvars is not None:  # pragma: no cover
        # Ensure we run in the same context: bind all arguments now and have the
        # worker thread invoke them through the copied context's `run`.
        child = functools.partial(func, *args, **kwargs)
        context = contextvars.copy_context()
        func = context.run
        args = (child,)
    elif kwargs:  # pragma: no cover
        # run_sync doesn't accept 'kwargs', so bind them in here
        func = functools.partial(func, **kwargs)
    return await anyio.to_thread.run_sync(func, *args)


class _StopIteration(Exception):
    """Stand-in for ``StopIteration`` when iterator exhaustion must cross a thread hop."""


def _next(iterator: Iterator[T]) -> T:
    """Return ``next(iterator)``, raising :class:`_StopIteration` once it is exhausted."""
    # We can't raise `StopIteration` from within the threadpool iterator
    # and catch it outside that context, so we coerce them into a different
    # exception type.
    try:
        return next(iterator)
    except StopIteration:
        # `from None`: the original StopIteration adds nothing useful to the chain.
        raise _StopIteration from None


async def iterate_in_threadpool(iterator: Iterable[T]) -> AsyncGenerator[T, None]:
    """Asynchronously yield the items of a synchronous iterable.

    Each ``next()`` call runs in a worker thread via ``anyio.to_thread.run_sync``.
    A blocking iterator therefore does not block the event loop between items.

    Any iterable is accepted, not only iterators. ``iter()`` is applied once up
    front, which is a no-op for objects that are already iterators. Iteration
    ends when the underlying iterator is exhausted.
    """
    # Without this, a plain iterable (e.g. a list) would make `next()` fail
    # with TypeError inside the worker thread.
    as_iterator = iter(iterator)
    while True:
        try:
            yield await anyio.to_thread.run_sync(_next, as_iterator)
        except _StopIteration:
            break