1import asyncio
2import typing
3
4from starlette.concurrency import run_in_threadpool
5
6
class BackgroundTask:
    """A callable bundled with the arguments it should later be invoked with.

    The callable is classified as async or sync once, at construction time.
    Awaiting the task instance then either awaits the callable directly
    (async) or delegates the call to ``run_in_threadpool`` (sync).

    Args:
        func: The callable to run. May be a plain function, a coroutine
            function, a ``functools.partial`` around either, or an object
            whose class defines ``__call__`` (sync or ``async def``).
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.
    """

    def __init__(
        self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
    ) -> None:
        self.func = func
        self.args = args
        self.kwargs = kwargs
        self.is_async = self._is_async_callable(func)

    @staticmethod
    def _is_async_callable(obj: typing.Any) -> bool:
        """Return True when calling ``obj`` produces a coroutine to await."""
        # Imported locally so the module-level import block is left untouched.
        import functools

        # A partial is only as async as the callable it ultimately wraps.
        while isinstance(obj, functools.partial):
            obj = obj.func

        if asyncio.iscoroutinefunction(obj):
            return True

        # Instances whose class defines ``async def __call__`` are not
        # coroutine functions themselves, yet calling them returns a
        # coroutine. Without this check they would take the sync branch and
        # the coroutine would never be awaited. Looking on the type (not the
        # instance) means a class passed as ``func`` stays on the sync path.
        return asyncio.iscoroutinefunction(getattr(type(obj), "__call__", None))

    async def __call__(self) -> None:
        """Run the stored callable with the stored arguments.

        Any exception raised by the callable propagates to the awaiter.
        """
        if self.is_async:
            await self.func(*self.args, **self.kwargs)
        else:
            await run_in_threadpool(self.func, *self.args, **self.kwargs)
21
22
class BackgroundTasks(BackgroundTask):
    """An ordered collection of background tasks, awaitable as a single task.

    Note that ``BackgroundTask.__init__`` is intentionally not invoked here,
    so instances carry only the ``tasks`` attribute (no ``func``, ``args``,
    ``kwargs`` or ``is_async``); ``__call__`` is overridden accordingly.

    Args:
        tasks: Optional initial tasks. ``None`` or an empty sequence starts
            the collection empty.
    """

    def __init__(
        self, tasks: typing.Optional[typing.Sequence[BackgroundTask]] = None
    ) -> None:
        # Copy into a new list so add_task() never mutates the caller's
        # sequence.
        self.tasks = list(tasks) if tasks else []

    def add_task(
        self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
    ) -> None:
        """Wrap ``func`` and its arguments in a BackgroundTask and queue it.

        Args:
            func: The callable to run later.
            *args: Positional arguments forwarded to ``func``.
            **kwargs: Keyword arguments forwarded to ``func``.
        """
        task = BackgroundTask(func, *args, **kwargs)
        self.tasks.append(task)

    async def __call__(self) -> None:
        """Await every queued task one after another, in insertion order.

        There is no error handling around individual tasks: an exception
        raised by one task propagates immediately and the tasks after it are
        not run.
        """
        for task in self.tasks:
            await task()
36