1# -*- coding: utf-8 -*-
2"""The code for async support. Importing this patches Jinja on supported
3Python versions.
4"""
5import asyncio
6import inspect
7from functools import update_wrapper
8
9from markupsafe import Markup
10
11from .environment import TemplateModule
12from .runtime import LoopContext
13from .utils import concat
14from .utils import internalcode
15from .utils import missing
16
17
async def concat_async(async_gen):
    """Drain an async iterable and join everything it produced.

    :param async_gen: async iterable whose items are collected in order.
    :return: ``concat`` applied to the list of collected items.
    """
    pieces = [piece async for piece in async_gen]
    return concat(pieces)
27
28
async def generate_async(self, *args, **kwargs):
    """Asynchronously yield the rendered events of the template.

    The positional and keyword arguments are passed to ``dict()`` to build
    the template variables. Each event produced by the template's
    ``root_render_func`` for the new context is yielded in order.

    :raises RuntimeError: if the environment was not created with async
        mode enabled. Like any async generator, this is raised on the first
        iteration, not at call time. The check mirrors ``render_async``.
        Without it, a non-async environment would presumably only fail
        once ``async for`` hits a non-async render function.

    Any ``Exception`` raised while rendering is handed to
    ``environment.handle_exception()``, and its return value is yielded.
    NOTE(review): presumably ``handle_exception`` re-raises; confirm.
    """
    if not self.environment.is_async:
        raise RuntimeError("The environment was not created with async mode enabled.")

    context_vars = dict(*args, **kwargs)
    try:
        async for event in self.root_render_func(self.new_context(context_vars)):
            yield event
    except Exception:
        yield self.environment.handle_exception()
36
37
def wrap_generate_func(original_generate):
    """Wrap ``Template.generate`` so it also works in async mode.

    If the template's environment is not async, the original method is
    called unchanged. Otherwise ``generate_async`` is driven one step at a
    time on the loop returned by ``asyncio.get_event_loop()``, and each
    event is yielded synchronously.

    :param original_generate: the unpatched ``generate`` method.
    :return: the wrapper, with the original's metadata copied onto it.
    """

    def _convert_generator(self, loop, args, kwargs):
        # Run the async generator's ``__anext__`` coroutine to completion on
        # ``loop`` for every event, turning it into a plain sync generator.
        async_gen = self.generate_async(*args, **kwargs)
        try:
            while True:
                yield loop.run_until_complete(async_gen.__anext__())
        except StopAsyncIteration:
            pass
        finally:
            # If the consumer stops early (``break``, ``close()``, garbage
            # collection), finalize the async generator now so its cleanup
            # code runs. Otherwise it stays suspended, half-consumed. This
            # is a no-op if it already finished. It is best-effort: it is
            # skipped when the loop can no longer run it.
            if not loop.is_closed() and not loop.is_running():
                loop.run_until_complete(async_gen.aclose())

    def generate(self, *args, **kwargs):
        if not self.environment.is_async:
            return original_generate(self, *args, **kwargs)
        return _convert_generator(self, asyncio.get_event_loop(), args, kwargs)

    return update_wrapper(generate, original_generate)
53
54
async def render_async(self, *args, **kwargs):
    """Render the template asynchronously and return the joined output.

    The arguments are passed to ``dict()`` to build the template variables.

    :raises RuntimeError: if the environment is not in async mode.
    :return: the output of ``concat_async`` over the render stream, or
        whatever ``environment.handle_exception()`` returns if rendering
        raised an ``Exception``.
    """
    if not self.environment.is_async:
        raise RuntimeError("The environment was not created with async mode enabled.")

    # The context is built outside the ``try``, so errors creating it are
    # not routed through ``handle_exception``.
    context = self.new_context(dict(*args, **kwargs))

    try:
        rendered = await concat_async(self.root_render_func(context))
    except Exception:
        return self.environment.handle_exception()
    return rendered
66
67
def wrap_render_func(original_render):
    """Wrap ``Template.render`` so it also works in async mode.

    In async mode, ``render_async`` is run to completion on the loop
    returned by ``asyncio.get_event_loop()`` and its result is returned.
    Otherwise the original ``render`` is called unchanged.

    :param original_render: the unpatched ``render`` method.
    :return: the wrapper, with the original's metadata copied onto it.
    """

    def render(self, *args, **kwargs):
        if self.environment.is_async:
            # The loop is fetched before the coroutine is created.
            event_loop = asyncio.get_event_loop()
            return event_loop.run_until_complete(self.render_async(*args, **kwargs))
        return original_render(self, *args, **kwargs)

    return update_wrapper(render, original_render)
76
77
def wrap_block_reference_call(original_call):
    """Make ``BlockReference.__call__`` return a coroutine in async mode.

    Outside async mode the original ``__call__`` is used as-is. In async
    mode, calling the reference returns a coroutine. The coroutine renders
    the block function at ``self._stack[self._depth]`` with the reference's
    context and joins the output via ``concat_async``. It wraps the result
    in ``Markup`` when ``eval_ctx.autoescape`` is set.

    :param original_call: the unpatched ``__call__``.
    :return: the wrapper, with the original's metadata copied onto it.
    """

    @internalcode
    async def async_call(self):
        block_func = self._stack[self._depth]
        output = await concat_async(block_func(self._context))
        if not self._context.eval_ctx.autoescape:
            return output
        return Markup(output)

    @internalcode
    def __call__(self):
        if self._context.environment.is_async:
            return async_call(self)
        return original_call(self)

    return update_wrapper(__call__, original_call)
93
94
def wrap_macro_invoke(original_invoke):
    """Make ``Macro._invoke`` return a coroutine in async mode.

    Outside async mode the original ``_invoke`` is used as-is. In async
    mode, the returned coroutine awaits ``self._func(*arguments)`` and
    wraps the result in ``Markup`` when *autoescape* is truthy.

    :param original_invoke: the unpatched ``_invoke``.
    :return: the wrapper, with the original's metadata copied onto it.
    """

    @internalcode
    async def async_invoke(self, arguments, autoescape):
        result = await self._func(*arguments)
        return Markup(result) if autoescape else result

    @internalcode
    def _invoke(self, arguments, autoescape):
        if self._environment.is_async:
            return async_invoke(self, arguments, autoescape)
        return original_invoke(self, arguments, autoescape)

    return update_wrapper(_invoke, original_invoke)
110
111
@internalcode
async def get_default_module_async(self):
    """Return the template's default module, building it on first use.

    The module comes from ``await self.make_module_async()`` and is cached
    on ``self._module``. Later calls return the cached object.
    """
    cached = self._module
    if cached is None:
        self._module = cached = await self.make_module_async()
    return cached
118
119
def wrap_default_module(original_default_module):
    """Wrap ``Template._get_default_module`` to refuse use in async mode.

    In async mode the wrapper raises ``RuntimeError``. The async accessor
    is ``_get_default_module_async``. Otherwise it delegates to the
    original method.

    :param original_default_module: the unpatched method.
    :return: the wrapper. Like the other wrappers in this module, it has
        the original's metadata (``__doc__``, ``__qualname__``,
        ``__wrapped__``, ...) copied onto it.
    """

    @internalcode
    def _get_default_module(self):
        if self.environment.is_async:
            raise RuntimeError("Template module attribute is unavailable in async mode")
        return original_default_module(self)

    return update_wrapper(_get_default_module, original_default_module)
128
129
async def make_module_async(self, vars=None, shared=False, locals=None):
    """Render the template asynchronously into a ``TemplateModule``.

    A new context is created from *vars*, *shared* and *locals*. Every
    item the async root render function yields for it is collected, and
    the module is built from the template, that context and the collected
    body.
    """
    context = self.new_context(vars, shared, locals)
    rendered_body = [chunk async for chunk in self.root_render_func(context)]
    return TemplateModule(self, context, rendered_body)
136
137
def patch_template():
    """Install the async-aware methods on ``Template``."""
    from . import Template as tmpl_cls

    # The sync entry points are wrapped so they dispatch on ``is_async``.
    # The async implementations replace the existing attributes and take
    # over their metadata via ``update_wrapper``.
    tmpl_cls.generate = wrap_generate_func(tmpl_cls.generate)
    tmpl_cls.generate_async = update_wrapper(generate_async, tmpl_cls.generate_async)
    tmpl_cls.render_async = update_wrapper(render_async, tmpl_cls.render_async)
    tmpl_cls.render = wrap_render_func(tmpl_cls.render)
    tmpl_cls._get_default_module = wrap_default_module(tmpl_cls._get_default_module)
    tmpl_cls._get_default_module_async = get_default_module_async
    tmpl_cls.make_module_async = update_wrapper(
        make_module_async, tmpl_cls.make_module_async
    )
150
151
def patch_runtime():
    """Make block-reference calls and macro invocation async-aware."""
    from .runtime import BlockReference
    from .runtime import Macro

    # Each (class, attribute, wrapper) triple is patched in this order.
    for owner, attr_name, wrapper in (
        (BlockReference, "__call__", wrap_block_reference_call),
        (Macro, "_invoke", wrap_macro_invoke),
    ):
        setattr(owner, attr_name, wrapper(getattr(owner, attr_name)))
157
158
def patch_filters():
    """Merge the async filters into the global ``FILTERS`` registry.

    Entries in ``ASYNC_FILTERS`` replace same-named entries in ``FILTERS``.
    """
    from .filters import FILTERS as builtin_filters
    from .asyncfilters import ASYNC_FILTERS as async_filters

    builtin_filters.update(async_filters)
164
165
def patch_all():
    """Apply every async patch: templates, then runtime, then filters."""
    for apply_patch in (patch_template, patch_runtime, patch_filters):
        apply_patch()
170
171
async def auto_await(value):
    """Return *value*, awaiting it first if it is awaitable."""
    if not inspect.isawaitable(value):
        return value
    return await value
176
177
async def auto_aiter(iterable):
    """Asynchronously yield the items of a sync or async iterable.

    Objects with ``__aiter__`` are consumed with ``async for``. Anything
    else is consumed with a regular ``for`` loop.
    """
    if not hasattr(iterable, "__aiter__"):
        for element in iterable:
            yield element
    else:
        async for element in iterable:
            yield element
185
186
class AsyncLoopContext(LoopContext):
    """Async variant of ``LoopContext``.

    Presumably this backs the ``loop`` variable in ``{% for %}`` loops
    rendered in async mode (TODO confirm). Attributes that may need to look
    ahead in the iterable (``length``, ``revindex0``, ``revindex``,
    ``last``, ``nextitem``) are coroutine properties and must be awaited.
    Iteration goes through ``__aiter__`` / ``__anext__``, each step
    producing ``(item, self)``.

    NOTE(review): this relies on state set up by ``LoopContext`` that is
    not visible here: ``_iterable``, ``_iterator``, ``_length``,
    ``_after``, ``_before``, ``_current``, ``index0``, ``index`` and
    ``_undefined``.
    """

    # Converts the source iterable (sync or async) into an async iterator.
    # Presumably LoopContext uses this to build ``_iterator``; confirm.
    _to_iterator = staticmethod(auto_aiter)

    @property
    async def length(self):
        """Total number of items in the loop (cached after first use)."""
        if self._length is not None:
            return self._length

        try:
            self._length = len(self._iterable)
        except TypeError:
            # The iterable has no len(): drain the remaining items into a
            # list and re-wrap it as the iterator so iteration can resume.
            iterable = [x async for x in self._iterator]
            self._iterator = self._to_iterator(iterable)
            # Remaining items + items already consumed (``index``) + one
            # more if a peeked item is pending in ``_after``.
            self._length = len(iterable) + self.index + (self._after is not missing)

        return self._length

    @property
    async def revindex0(self):
        """``length - index``."""
        # ``await`` binds tighter than ``-``: (await self.length) - self.index
        return await self.length - self.index

    @property
    async def revindex(self):
        """``length - index0``."""
        return await self.length - self.index0

    async def _peek_next(self):
        """Return the next item without consuming it.

        The item is cached in ``_after`` so ``__anext__`` returns it next.
        Returns ``missing`` when the iterator is exhausted.
        """
        if self._after is not missing:
            return self._after

        try:
            self._after = await self._iterator.__anext__()
        except StopAsyncIteration:
            self._after = missing

        return self._after

    @property
    async def last(self):
        """True when there is no item after the current one."""
        return await self._peek_next() is missing

    @property
    async def nextitem(self):
        """The next item, or an undefined value if this is the last one."""
        rv = await self._peek_next()

        if rv is missing:
            return self._undefined("there is no next item")

        return rv

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Advance the loop and return ``(item, self)``.

        A previously peeked item is consumed first. Otherwise the next item
        is pulled from ``_iterator``, whose ``StopAsyncIteration`` ends the
        loop.
        """
        if self._after is not missing:
            rv = self._after
            self._after = missing
        else:
            rv = await self._iterator.__anext__()

        self.index0 += 1
        self._before = self._current
        self._current = rv
        return rv, self
250
251
async def make_async_loop_context(iterable, undefined, recurse=None, depth0=0):
    """Deprecated factory for ``AsyncLoopContext``.

    Emits a ``DeprecationWarning`` when awaited. Its text says templates
    calling this must be recompiled with Jinja 2.11 or later. It then
    returns ``AsyncLoopContext(iterable, undefined, recurse, depth0)``.
    """
    import warnings

    deprecation_message = (
        "This template must be recompiled with at least Jinja 2.11, or"
        " it will fail in 3.0."
    )
    warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
    return AsyncLoopContext(iterable, undefined, recurse, depth0)
262
263
# Importing this module applies all async patches (Template methods,
# runtime BlockReference/Macro, and the filter registry) as a side effect.
patch_all()
265