1import inspect
2import typing as t
3from functools import wraps
4
5from .utils import _PassArg
6from .utils import pass_eval_context
7
8V = t.TypeVar("V")
9
10
def async_variant(normal_func):  # type: ignore
    """Build a decorator that pairs ``normal_func`` with an async twin.

    The returned decorator takes ``async_func`` and produces a wrapper that
    forwards each call to ``async_func`` when the environment reports
    ``is_async`` as true, and to ``normal_func`` otherwise.

    The environment is located through the first positional argument of the
    call: it is that argument itself when ``_PassArg.from_obj(normal_func)``
    is ``_PassArg.environment``; in every other case it is read from that
    argument's ``environment`` attribute.

    When ``_PassArg.from_obj(normal_func)`` is ``None`` the wrapper is
    additionally decorated with ``pass_eval_context`` and the first positional
    argument is removed before the selected function is called (presumably
    that argument is the eval context injected for the check only -- confirm
    against ``pass_eval_context``).

    The wrapper that is returned carries ``jinja_async_variant = True``.
    """

    def decorator(async_func):  # type: ignore
        pass_arg = _PassArg.from_obj(normal_func)
        env_is_first_arg = pass_arg is _PassArg.environment
        need_eval_context = pass_arg is None

        def is_async(args: t.Any) -> bool:
            # The leading argument is either the environment or an object
            # exposing it as ``.environment``.
            head = args[0]
            env = head if env_is_first_arg else head.environment
            return t.cast(bool, env.is_async)

        @wraps(normal_func)
        def dispatcher(*args, **kwargs):  # type: ignore
            # Decide on the full argument tuple, before the leading argument
            # is possibly stripped below.
            target = async_func if is_async(args) else normal_func

            if need_eval_context:
                # The leading argument was only needed for the async check.
                args = args[1:]

            return target(*args, **kwargs)

        result = pass_eval_context(dispatcher) if need_eval_context else dispatcher
        result.jinja_async_variant = True
        return result

    return decorator
45
46
async def auto_await(value: t.Union[t.Awaitable["V"], "V"]) -> "V":
    """Return ``value``, awaiting it first if it is awaitable.

    Objects for which ``inspect.isawaitable`` is false are returned
    unchanged; awaitables are awaited and their result is returned.
    """
    if not inspect.isawaitable(value):
        return t.cast("V", value)

    pending = t.cast("t.Awaitable[V]", value)
    return await pending
52
53
async def auto_aiter(
    iterable: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
) -> "t.AsyncIterator[V]":
    """Yield the items of ``iterable`` from an async generator.

    Objects that have an ``__aiter__`` attribute are consumed with
    ``async for``; anything else is consumed with a regular ``for`` loop.
    """
    if not hasattr(iterable, "__aiter__"):
        # Plain iterable: walk it synchronously and re-yield each element.
        for element in t.cast("t.Iterable[V]", iterable):
            yield element
        return

    async for element in t.cast("t.AsyncIterable[V]", iterable):
        yield element
63
64
async def auto_to_list(
    value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
) -> t.List["V"]:
    """Collect every item of ``value`` into a new list.

    ``value`` is iterated through ``auto_aiter``, and the items are returned
    in the order they are produced.
    """
    collected = []

    async for item in auto_aiter(value):
        collected.append(item)

    return collected
69