1# mako/codegen.py
2# Copyright 2006-2019 the Mako authors and contributors <see AUTHORS file>
3#
4# This module is part of Mako and is released under
5# the MIT License: http://www.opensource.org/licenses/mit-license.php
6
"""provides functionality for rendering a parsetree into module
source code."""
9
10import re
11import time
12
13from mako import ast
14from mako import compat
15from mako import exceptions
16from mako import filters
17from mako import parsetree
18from mako import util
19from mako.pygen import PythonPrinter
20
21
# emitted into every generated module as ``_magic_number``
MAGIC_NUMBER = 10

# Names that are hardwired into the generated template module instead
# of being looked up through the context.
TOPLEVEL_DECLARED = {"UNDEFINED", "STOP_RENDERING"}
RESERVED_NAMES = {"context", "loop"} | TOPLEVEL_DECLARED
29
30
def compile(  # noqa
    node,
    uri,
    filename=None,
    default_filters=None,
    buffer_filters=None,
    imports=None,
    future_imports=None,
    source_encoding=None,
    generate_magic_comment=True,
    disable_unicode=False,
    strict_undefined=False,
    enable_loop=True,
    reserved_names=frozenset(),
):
    """Turn a parsetree into the source text of a template module.

    ``node`` is the root parsetree node, ``uri`` the template uri and
    ``filename`` the optional source filename.  All remaining keyword
    arguments are stored unchanged on the :class:`_CompileContext` that
    drives code generation.  Returns the contents of the output buffer
    once generation has finished.

    """

    # Py2K only: coerce a unicode "source_encoding" into a bytestring.
    # The value is embedded into the generated source and, in
    # "disable_unicode" mode, must not coerce the result into a unicode
    # object.
    if not compat.py3k:
        if isinstance(source_encoding, compat.text_type):
            source_encoding = source_encoding.encode(source_encoding)

    settings = _CompileContext(
        uri,
        filename,
        default_filters,
        buffer_filters,
        imports,
        future_imports,
        source_encoding,
        generate_magic_comment,
        disable_unicode,
        strict_undefined,
        enable_loop,
        reserved_names,
    )

    output = util.FastEncodingBuffer()
    # instantiating the generator performs the whole code generation
    _GenerateRenderMethod(PythonPrinter(output), settings, node)
    return output.getvalue()
78
79
class _CompileContext(object):
    """Carries the settings of one ``compile()`` call.

    Each constructor argument is kept on the instance under the same
    name.  The code generator additionally attaches state of its own
    while it runs (e.g. ``pagetag``, ``namespaces`` and ``identifiers``).

    """

    def __init__(
        self,
        uri,
        filename,
        default_filters,
        buffer_filters,
        imports,
        future_imports,
        source_encoding,
        generate_magic_comment,
        disable_unicode,
        strict_undefined,
        enable_loop,
        reserved_names,
    ):
        # where the template comes from
        self.uri, self.filename = uri, filename
        self.source_encoding = source_encoding

        # filters for expressions and for buffered output
        self.default_filters = default_filters
        self.buffer_filters = buffer_filters

        # import lines placed into the generated module
        self.imports, self.future_imports = imports, future_imports

        # code generation switches
        self.generate_magic_comment = generate_magic_comment
        self.disable_unicode = disable_unicode
        self.strict_undefined = strict_undefined
        self.enable_loop = enable_loop
        self.reserved_names = reserved_names
108
109
110class _GenerateRenderMethod(object):
111
112    """A template visitor object which generates the
113       full module source for a template.
114
115    """
116
117    def __init__(self, printer, compiler, node):
118        self.printer = printer
119        self.compiler = compiler
120        self.node = node
121        self.identifier_stack = [None]
122        self.in_def = isinstance(node, (parsetree.DefTag, parsetree.BlockTag))
123
124        if self.in_def:
125            name = "render_%s" % node.funcname
126            args = node.get_argument_expressions()
127            filtered = len(node.filter_args.args) > 0
128            buffered = eval(node.attributes.get("buffered", "False"))
129            cached = eval(node.attributes.get("cached", "False"))
130            defs = None
131            pagetag = None
132            if node.is_block and not node.is_anonymous:
133                args += ["**pageargs"]
134        else:
135            defs = self.write_toplevel()
136            pagetag = self.compiler.pagetag
137            name = "render_body"
138            if pagetag is not None:
139                args = pagetag.body_decl.get_argument_expressions()
140                if not pagetag.body_decl.kwargs:
141                    args += ["**pageargs"]
142                cached = eval(pagetag.attributes.get("cached", "False"))
143                self.compiler.enable_loop = self.compiler.enable_loop or eval(
144                    pagetag.attributes.get("enable_loop", "False")
145                )
146            else:
147                args = ["**pageargs"]
148                cached = False
149            buffered = filtered = False
150        if args is None:
151            args = ["context"]
152        else:
153            args = [a for a in ["context"] + args]
154
155        self.write_render_callable(
156            pagetag or node, name, args, buffered, filtered, cached
157        )
158
159        if defs is not None:
160            for node in defs:
161                _GenerateRenderMethod(printer, compiler, node)
162
163        if not self.in_def:
164            self.write_metadata_struct()
165
166    def write_metadata_struct(self):
167        self.printer.source_map[self.printer.lineno] = max(
168            self.printer.source_map
169        )
170        struct = {
171            "filename": self.compiler.filename,
172            "uri": self.compiler.uri,
173            "source_encoding": self.compiler.source_encoding,
174            "line_map": self.printer.source_map,
175        }
176        self.printer.writelines(
177            '"""',
178            "__M_BEGIN_METADATA",
179            compat.json.dumps(struct),
180            "__M_END_METADATA\n" '"""',
181        )
182
183    @property
184    def identifiers(self):
185        return self.identifier_stack[-1]
186
    def write_toplevel(self):
        """Traverse a template structure for module-level directives and
        generate the start of module-level code.

        As side effects this sets ``pagetag``, ``namespaces`` and
        ``identifiers`` on the compiler object.  Returns the list of
        top-level def nodes found in the template, so the caller can
        generate a render callable for each of them.

        """
        inherit = []
        namespaces = {}
        module_code = []

        self.compiler.pagetag = None

        # visitor collecting the module-level directives; "s" is the
        # visitor instance, while "self" remains the enclosing generator
        class FindTopLevel(object):
            def visitInheritTag(s, node):
                inherit.append(node)

            def visitNamespaceTag(s, node):
                namespaces[node.name] = node

            def visitPageTag(s, node):
                self.compiler.pagetag = node

            def visitCode(s, node):
                # only module-level code blocks are collected here
                if node.ismodule:
                    module_code.append(node)

        f = FindTopLevel()
        # only the direct children of the root node are visited
        for n in self.node.nodes:
            n.accept_visitor(f)

        self.compiler.namespaces = namespaces

        # names declared by module-level code blocks
        module_ident = set()
        for n in module_code:
            module_ident = module_ident.union(n.declared_identifiers())

        module_identifiers = _Identifiers(self.compiler)
        module_identifiers.declared = module_ident

        # module-level names, python code
        if (
            self.compiler.generate_magic_comment
            and self.compiler.source_encoding
        ):
            self.printer.writeline(
                "# -*- coding:%s -*-" % self.compiler.source_encoding
            )

        if self.compiler.future_imports:
            self.printer.writeline(
                "from __future__ import %s"
                % (", ".join(self.compiler.future_imports),)
            )
        self.printer.writeline("from mako import runtime, filters, cache")
        self.printer.writeline("UNDEFINED = runtime.UNDEFINED")
        self.printer.writeline("STOP_RENDERING = runtime.STOP_RENDERING")
        self.printer.writeline("__M_dict_builtin = dict")
        self.printer.writeline("__M_locals_builtin = locals")
        self.printer.writeline("_magic_number = %r" % MAGIC_NUMBER)
        # the generation time is stamped into the module
        self.printer.writeline("_modified_time = %r" % time.time())
        self.printer.writeline("_enable_loop = %r" % self.compiler.enable_loop)
        self.printer.writeline(
            "_template_filename = %r" % self.compiler.filename
        )
        self.printer.writeline("_template_uri = %r" % self.compiler.uri)
        self.printer.writeline(
            "_source_encoding = %r" % self.compiler.source_encoding
        )
        if self.compiler.imports:
            # the import lines are written verbatim and also parsed as
            # one code block, so the names they declare are known below
            buf = ""
            for imp in self.compiler.imports:
                buf += imp + "\n"
                self.printer.writeline(imp)
            impcode = ast.PythonCode(
                buf,
                source="",
                lineno=0,
                pos=0,
                filename="template defined imports",
            )
        else:
            impcode = None

        # merge the top-level defs found in the template body into the
        # module-level identifiers
        main_identifiers = module_identifiers.branch(self.node)
        mit = module_identifiers.topleveldefs
        module_identifiers.topleveldefs = mit.union(
            main_identifiers.topleveldefs
        )
        module_identifiers.declared.update(TOPLEVEL_DECLARED)
        if impcode:
            module_identifiers.declared.update(impcode.declared_identifiers)

        self.compiler.identifiers = module_identifiers
        self.printer.writeline(
            "_exports = %r"
            % [n.name for n in main_identifiers.topleveldefs.values()]
        )
        self.printer.write_blanks(2)

        if len(module_code):
            self.write_module_code(module_code)

        # the namespace callables are written whenever the template
        # inherits, even without namespaces: the generated
        # _mako_inherit() calls _mako_generate_namespaces().  if several
        # inherit tags are present, only the last one is used.
        if len(inherit):
            self.write_namespaces(namespaces)
            self.write_inherit(inherit[-1])
        elif len(namespaces):
            self.write_namespaces(namespaces)

        return list(main_identifiers.topleveldefs.values())
295
296    def write_render_callable(
297        self, node, name, args, buffered, filtered, cached
298    ):
299        """write a top-level render callable.
300
301        this could be the main render() method or that of a top-level def."""
302
303        if self.in_def:
304            decorator = node.decorator
305            if decorator:
306                self.printer.writeline(
307                    "@runtime._decorate_toplevel(%s)" % decorator
308                )
309
310        self.printer.start_source(node.lineno)
311        self.printer.writelines(
312            "def %s(%s):" % (name, ",".join(args)),
313            # push new frame, assign current frame to __M_caller
314            "__M_caller = context.caller_stack._push_frame()",
315            "try:",
316        )
317        if buffered or filtered or cached:
318            self.printer.writeline("context._push_buffer()")
319
320        self.identifier_stack.append(
321            self.compiler.identifiers.branch(self.node)
322        )
323        if (not self.in_def or self.node.is_block) and "**pageargs" in args:
324            self.identifier_stack[-1].argument_declared.add("pageargs")
325
326        if not self.in_def and (
327            len(self.identifiers.locally_assigned) > 0
328            or len(self.identifiers.argument_declared) > 0
329        ):
330            self.printer.writeline(
331                "__M_locals = __M_dict_builtin(%s)"
332                % ",".join(
333                    [
334                        "%s=%s" % (x, x)
335                        for x in self.identifiers.argument_declared
336                    ]
337                )
338            )
339
340        self.write_variable_declares(self.identifiers, toplevel=True)
341
342        for n in self.node.nodes:
343            n.accept_visitor(self)
344
345        self.write_def_finish(self.node, buffered, filtered, cached)
346        self.printer.writeline(None)
347        self.printer.write_blanks(2)
348        if cached:
349            self.write_cache_decorator(
350                node, name, args, buffered, self.identifiers, toplevel=True
351            )
352
353    def write_module_code(self, module_code):
354        """write module-level template code, i.e. that which
355        is enclosed in <%! %> tags in the template."""
356        for n in module_code:
357            self.printer.write_indented_block(n.text, starting_lineno=n.lineno)
358
359    def write_inherit(self, node):
360        """write the module-level inheritance-determination callable."""
361
362        self.printer.writelines(
363            "def _mako_inherit(template, context):",
364            "_mako_generate_namespaces(context)",
365            "return runtime._inherit_from(context, %s, _template_uri)"
366            % (node.parsed_attributes["file"]),
367            None,
368        )
369
    def write_namespaces(self, namespaces):
        """write the module-level namespace-generating callable.

        Two functions are emitted: ``_mako_get_namespace()``, which looks
        a namespace up in ``context.namespaces`` and generates all
        namespaces on a ``KeyError`` before retrying, and
        ``_mako_generate_namespaces()``, which builds one namespace
        object per entry of ``namespaces`` (a dict of name to
        namespace tag node).

        """
        # NOTE(review): a None entry presumably closes the current
        # indentation level in PythonPrinter - confirm there
        self.printer.writelines(
            "def _mako_get_namespace(context, name):",
            "try:",
            "return context.namespaces[(__name__, name)]",
            "except KeyError:",
            "_mako_generate_namespaces(context)",
            "return context.namespaces[(__name__, name)]",
            None,
            None,
        )
        self.printer.writeline("def _mako_generate_namespaces(context):")

        for node in namespaces.values():
            if "import" in node.attributes:
                # flag read later by write_variable_declares()
                self.compiler.has_ns_imports = True
            self.printer.start_source(node.lineno)
            if len(node.nodes):
                # the namespace tag has a body: its defs are written
                # inline into a local make_namespace() function which
                # returns them as a list
                self.printer.writeline("def make_namespace():")
                export = []
                identifiers = self.compiler.identifiers.branch(node)
                self.in_def = True

                # "s" is the visitor instance; "self" remains the
                # enclosing generator
                class NSDefVisitor(object):
                    def visitDefTag(s, node):
                        s.visitDefOrBase(node)

                    def visitBlockTag(s, node):
                        s.visitDefOrBase(node)

                    def visitDefOrBase(s, node):
                        if node.is_anonymous:
                            raise exceptions.CompileException(
                                "Can't put anonymous blocks inside "
                                "<%namespace>",
                                **node.exception_kwargs
                            )
                        self.write_inline_def(node, identifiers, nested=False)
                        export.append(node.funcname)

                vis = NSDefVisitor()
                for n in node.nodes:
                    n.accept_visitor(vis)
                self.printer.writeline("return [%s]" % (",".join(export)))
                self.printer.writeline(None)
                # in_def is reset to False, not to its previous value
                self.in_def = False
                callable_name = "make_namespace()"
            else:
                callable_name = "None"

            # choose the runtime namespace class from the attributes
            # present on the tag: "file", else "module", else neither
            if "file" in node.parsed_attributes:
                self.printer.writeline(
                    "ns = runtime.TemplateNamespace(%r,"
                    " context._clean_inheritance_tokens(),"
                    " templateuri=%s, callables=%s, "
                    " calling_uri=_template_uri)"
                    % (
                        node.name,
                        node.parsed_attributes.get("file", "None"),
                        callable_name,
                    )
                )
            elif "module" in node.parsed_attributes:
                self.printer.writeline(
                    "ns = runtime.ModuleNamespace(%r,"
                    " context._clean_inheritance_tokens(),"
                    " callables=%s, calling_uri=_template_uri,"
                    " module=%s)"
                    % (
                        node.name,
                        callable_name,
                        node.parsed_attributes.get("module", "None"),
                    )
                )
            else:
                self.printer.writeline(
                    "ns = runtime.Namespace(%r,"
                    " context._clean_inheritance_tokens(),"
                    " callables=%s, calling_uri=_template_uri)"
                    % (node.name, callable_name)
                )
            # NOTE: the "inheritable" attribute text is eval()'d as
            # Python source
            if eval(node.attributes.get("inheritable", "False")):
                self.printer.writeline("context['self'].%s = ns" % (node.name))

            self.printer.writeline(
                "context.namespaces[(__name__, %s)] = ns" % repr(node.name)
            )
            self.printer.write_blanks(1)
        # with no namespaces at all the generated function still needs
        # a body
        if not len(namespaces):
            self.printer.writeline("pass")
        self.printer.writeline(None)
462
    def write_variable_declares(self, identifiers, toplevel=False, limit=None):
        """write variable declarations at the top of a function.

        the variable declarations are in the form of callable
        definitions for defs and/or name lookup within the
        function's context argument. the names declared are based
        on the names that are referenced in the function body,
        which don't otherwise have any explicit assignment
        operation. names that are assigned within the body are
        assumed to be locally-scoped variables and are not
        separately declared.

        for def callable definitions, if the def is a top-level
        callable then a 'stub' callable is generated which wraps
        the current Context into a closure. if the def is not
        top-level, it is fully rendered as a local closure.

        ``identifiers`` is the identifier scope of the function being
        written.  ``toplevel`` enables the ``_import_ns`` setup for
        namespaces which carry an ``import`` attribute.  ``limit``, when
        not None, restricts the declared names to those it contains.
        The final line written is always the ``__M_writer`` assignment.

        """

        # collection of all defs available to us in this scope
        comp_idents = dict([(c.funcname, c) for c in identifiers.defs])
        to_write = set()

        # write "context.get()" for all variables we are going to
        # need that aren't in the namespace yet
        to_write = to_write.union(identifiers.undeclared)

        # write closure functions for closures that we define
        # right here
        to_write = to_write.union(
            [c.funcname for c in identifiers.closuredefs.values()]
        )

        # remove identifiers that are declared in the argument
        # signature of the callable
        to_write = to_write.difference(identifiers.argument_declared)

        # remove identifiers that we are going to assign to.
        # in this way we mimic Python's behavior,
        # i.e. assignment to a variable within a block
        # means that variable is now a "locally declared" var,
        # which cannot be referenced beforehand.
        to_write = to_write.difference(identifiers.locally_declared)

        # with the loop context enabled, "loop" is not fetched from the
        # context; it gets a LoopStack of its own further below
        if self.compiler.enable_loop:
            has_loop = "loop" in to_write
            to_write.discard("loop")
        else:
            has_loop = False

        # if a limiting set was sent, constrain to those items in that list
        # (this is used for the caching decorator)
        if limit is not None:
            to_write = to_write.intersection(limit)

        # has_ns_imports is only set (by write_namespaces) when some
        # namespace tag carries an "import" attribute, hence getattr()
        if toplevel and getattr(self.compiler, "has_ns_imports", False):
            self.printer.writeline("_import_ns = {}")
            self.compiler.has_imports = True
            for ident, ns in self.compiler.namespaces.items():
                if "import" in ns.attributes:
                    # the attribute is a comma separated list of names
                    self.printer.writeline(
                        "_mako_get_namespace(context, %r)."
                        "_populate(_import_ns, %r)"
                        % (
                            ident,
                            re.split(r"\s*,\s*", ns.attributes["import"]),
                        )
                    )

        if has_loop:
            self.printer.writeline("loop = __M_loop = runtime.LoopStack()")

        # NOTE: to_write is a set, so the order in which the
        # declarations are written follows set iteration order
        for ident in to_write:
            if ident in comp_idents:
                # the name refers to a def/block available in this scope
                comp = comp_idents[ident]
                if comp.is_block:
                    # named blocks get a stub, anonymous ones are inlined
                    if not comp.is_anonymous:
                        self.write_def_decl(comp, identifiers)
                    else:
                        self.write_inline_def(comp, identifiers, nested=True)
                else:
                    # root defs get a stub, nested ones are inlined
                    if comp.is_root():
                        self.write_def_decl(comp, identifiers)
                    else:
                        self.write_inline_def(comp, identifiers, nested=True)

            elif ident in self.compiler.namespaces:
                self.printer.writeline(
                    "%s = _mako_get_namespace(context, %r)" % (ident, ident)
                )
            else:
                # a plain name: look it up in the imported namespace
                # names first (if there are any), then in the context
                if getattr(self.compiler, "has_ns_imports", False):
                    if self.compiler.strict_undefined:
                        # strict mode: a missing name raises NameError
                        # instead of evaluating to UNDEFINED
                        self.printer.writelines(
                            "%s = _import_ns.get(%r, UNDEFINED)"
                            % (ident, ident),
                            "if %s is UNDEFINED:" % ident,
                            "try:",
                            "%s = context[%r]" % (ident, ident),
                            "except KeyError:",
                            "raise NameError(\"'%s' is not defined\")" % ident,
                            None,
                            None,
                        )
                    else:
                        self.printer.writeline(
                            "%s = _import_ns.get"
                            "(%r, context.get(%r, UNDEFINED))"
                            % (ident, ident, ident)
                        )
                else:
                    if self.compiler.strict_undefined:
                        self.printer.writelines(
                            "try:",
                            "%s = context[%r]" % (ident, ident),
                            "except KeyError:",
                            "raise NameError(\"'%s' is not defined\")" % ident,
                            None,
                        )
                    else:
                        self.printer.writeline(
                            "%s = context.get(%r, UNDEFINED)" % (ident, ident)
                        )

        self.printer.writeline("__M_writer = context.writer()")
588
589    def write_def_decl(self, node, identifiers):
590        """write a locally-available callable referencing a top-level def"""
591        funcname = node.funcname
592        namedecls = node.get_argument_expressions()
593        nameargs = node.get_argument_expressions(as_call=True)
594
595        if not self.in_def and (
596            len(self.identifiers.locally_assigned) > 0
597            or len(self.identifiers.argument_declared) > 0
598        ):
599            nameargs.insert(0, "context._locals(__M_locals)")
600        else:
601            nameargs.insert(0, "context")
602        self.printer.writeline("def %s(%s):" % (funcname, ",".join(namedecls)))
603        self.printer.writeline(
604            "return render_%s(%s)" % (funcname, ",".join(nameargs))
605        )
606        self.printer.writeline(None)
607
608    def write_inline_def(self, node, identifiers, nested):
609        """write a locally-available def callable inside an enclosing def."""
610
611        namedecls = node.get_argument_expressions()
612
613        decorator = node.decorator
614        if decorator:
615            self.printer.writeline(
616                "@runtime._decorate_inline(context, %s)" % decorator
617            )
618        self.printer.writeline(
619            "def %s(%s):" % (node.funcname, ",".join(namedecls))
620        )
621        filtered = len(node.filter_args.args) > 0
622        buffered = eval(node.attributes.get("buffered", "False"))
623        cached = eval(node.attributes.get("cached", "False"))
624        self.printer.writelines(
625            # push new frame, assign current frame to __M_caller
626            "__M_caller = context.caller_stack._push_frame()",
627            "try:",
628        )
629        if buffered or filtered or cached:
630            self.printer.writelines("context._push_buffer()")
631
632        identifiers = identifiers.branch(node, nested=nested)
633
634        self.write_variable_declares(identifiers)
635
636        self.identifier_stack.append(identifiers)
637        for n in node.nodes:
638            n.accept_visitor(self)
639        self.identifier_stack.pop()
640
641        self.write_def_finish(node, buffered, filtered, cached)
642        self.printer.writeline(None)
643        if cached:
644            self.write_cache_decorator(
645                node,
646                node.funcname,
647                namedecls,
648                False,
649                identifiers,
650                inline=True,
651                toplevel=False,
652            )
653
654    def write_def_finish(
655        self, node, buffered, filtered, cached, callstack=True
656    ):
657        """write the end section of a rendering function, either outermost or
658        inline.
659
660        this takes into account if the rendering function was filtered,
661        buffered, etc.  and closes the corresponding try: block if any, and
662        writes code to retrieve captured content, apply filters, send proper
663        return value."""
664
665        if not buffered and not cached and not filtered:
666            self.printer.writeline("return ''")
667            if callstack:
668                self.printer.writelines(
669                    "finally:", "context.caller_stack._pop_frame()", None
670                )
671
672        if buffered or filtered or cached:
673            if buffered or cached:
674                # in a caching scenario, don't try to get a writer
675                # from the context after popping; assume the caching
676                # implemenation might be using a context with no
677                # extra buffers
678                self.printer.writelines(
679                    "finally:", "__M_buf = context._pop_buffer()"
680                )
681            else:
682                self.printer.writelines(
683                    "finally:",
684                    "__M_buf, __M_writer = context._pop_buffer_and_writer()",
685                )
686
687            if callstack:
688                self.printer.writeline("context.caller_stack._pop_frame()")
689
690            s = "__M_buf.getvalue()"
691            if filtered:
692                s = self.create_filter_callable(
693                    node.filter_args.args, s, False
694                )
695            self.printer.writeline(None)
696            if buffered and not cached:
697                s = self.create_filter_callable(
698                    self.compiler.buffer_filters, s, False
699                )
700            if buffered or cached:
701                self.printer.writeline("return %s" % s)
702            else:
703                self.printer.writelines("__M_writer(%s)" % s, "return ''")
704
705    def write_cache_decorator(
706        self,
707        node_or_pagetag,
708        name,
709        args,
710        buffered,
711        identifiers,
712        inline=False,
713        toplevel=False,
714    ):
715        """write a post-function decorator to replace a rendering
716            callable with a cached version of itself."""
717
718        self.printer.writeline("__M_%s = %s" % (name, name))
719        cachekey = node_or_pagetag.parsed_attributes.get(
720            "cache_key", repr(name)
721        )
722
723        cache_args = {}
724        if self.compiler.pagetag is not None:
725            cache_args.update(
726                (pa[6:], self.compiler.pagetag.parsed_attributes[pa])
727                for pa in self.compiler.pagetag.parsed_attributes
728                if pa.startswith("cache_") and pa != "cache_key"
729            )
730        cache_args.update(
731            (pa[6:], node_or_pagetag.parsed_attributes[pa])
732            for pa in node_or_pagetag.parsed_attributes
733            if pa.startswith("cache_") and pa != "cache_key"
734        )
735        if "timeout" in cache_args:
736            cache_args["timeout"] = int(eval(cache_args["timeout"]))
737
738        self.printer.writeline("def %s(%s):" % (name, ",".join(args)))
739
740        # form "arg1, arg2, arg3=arg3, arg4=arg4", etc.
741        pass_args = [
742            "%s=%s" % ((a.split("=")[0],) * 2) if "=" in a else a for a in args
743        ]
744
745        self.write_variable_declares(
746            identifiers,
747            toplevel=toplevel,
748            limit=node_or_pagetag.undeclared_identifiers(),
749        )
750        if buffered:
751            s = (
752                "context.get('local')."
753                "cache._ctx_get_or_create("
754                "%s, lambda:__M_%s(%s),  context, %s__M_defname=%r)"
755                % (
756                    cachekey,
757                    name,
758                    ",".join(pass_args),
759                    "".join(
760                        ["%s=%s, " % (k, v) for k, v in cache_args.items()]
761                    ),
762                    name,
763                )
764            )
765            # apply buffer_filters
766            s = self.create_filter_callable(
767                self.compiler.buffer_filters, s, False
768            )
769            self.printer.writelines("return " + s, None)
770        else:
771            self.printer.writelines(
772                "__M_writer(context.get('local')."
773                "cache._ctx_get_or_create("
774                "%s, lambda:__M_%s(%s), context, %s__M_defname=%r))"
775                % (
776                    cachekey,
777                    name,
778                    ",".join(pass_args),
779                    "".join(
780                        ["%s=%s, " % (k, v) for k, v in cache_args.items()]
781                    ),
782                    name,
783                ),
784                "return ''",
785                None,
786            )
787
788    def create_filter_callable(self, args, target, is_expression):
789        """write a filter-applying expression based on the filters
790        present in the given filter names, adjusting for the global
791        'default' filter aliases as needed."""
792
793        def locate_encode(name):
794            if re.match(r"decode\..+", name):
795                return "filters." + name
796            elif self.compiler.disable_unicode:
797                return filters.NON_UNICODE_ESCAPES.get(name, name)
798            else:
799                return filters.DEFAULT_ESCAPES.get(name, name)
800
801        if "n" not in args:
802            if is_expression:
803                if self.compiler.pagetag:
804                    args = self.compiler.pagetag.filter_args.args + args
805                if self.compiler.default_filters and "n" not in args:
806                    args = self.compiler.default_filters + args
807        for e in args:
808            # if filter given as a function, get just the identifier portion
809            if e == "n":
810                continue
811            m = re.match(r"(.+?)(\(.*\))", e)
812            if m:
813                ident, fargs = m.group(1, 2)
814                f = locate_encode(ident)
815                e = f + fargs
816            else:
817                e = locate_encode(e)
818                assert e is not None
819            target = "%s(%s)" % (e, target)
820        return target
821
822    def visitExpression(self, node):
823        self.printer.start_source(node.lineno)
824        if (
825            len(node.escapes)
826            or (
827                self.compiler.pagetag is not None
828                and len(self.compiler.pagetag.filter_args.args)
829            )
830            or len(self.compiler.default_filters)
831        ):
832
833            s = self.create_filter_callable(
834                node.escapes_code.args, "%s" % node.text, True
835            )
836            self.printer.writeline("__M_writer(%s)" % s)
837        else:
838            self.printer.writeline("__M_writer(%s)" % node.text)
839
840    def visitControlLine(self, node):
841        if node.isend:
842            self.printer.writeline(None)
843            if node.has_loop_context:
844                self.printer.writeline("finally:")
845                self.printer.writeline("loop = __M_loop._exit()")
846                self.printer.writeline(None)
847        else:
848            self.printer.start_source(node.lineno)
849            if self.compiler.enable_loop and node.keyword == "for":
850                text = mangle_mako_loop(node, self.printer)
851            else:
852                text = node.text
853            self.printer.writeline(text)
854            children = node.get_children()
855            # this covers the three situations where we want to insert a pass:
856            #    1) a ternary control line with no children,
857            #    2) a primary control line with nothing but its own ternary
858            #          and end control lines, and
859            #    3) any control line with no content other than comments
860            if not children or (
861                compat.all(
862                    isinstance(c, (parsetree.Comment, parsetree.ControlLine))
863                    for c in children
864                )
865                and compat.all(
866                    (node.is_ternary(c.keyword) or c.isend)
867                    for c in children
868                    if isinstance(c, parsetree.ControlLine)
869                )
870            ):
871                self.printer.writeline("pass")
872
873    def visitText(self, node):
874        self.printer.start_source(node.lineno)
875        self.printer.writeline("__M_writer(%s)" % repr(node.content))
876
877    def visitTextTag(self, node):
878        filtered = len(node.filter_args.args) > 0
879        if filtered:
880            self.printer.writelines(
881                "__M_writer = context._push_writer()", "try:"
882            )
883        for n in node.nodes:
884            n.accept_visitor(self)
885        if filtered:
886            self.printer.writelines(
887                "finally:",
888                "__M_buf, __M_writer = context._pop_buffer_and_writer()",
889                "__M_writer(%s)"
890                % self.create_filter_callable(
891                    node.filter_args.args, "__M_buf.getvalue()", False
892                ),
893                None,
894            )
895
896    def visitCode(self, node):
897        if not node.ismodule:
898            self.printer.write_indented_block(
899                node.text, starting_lineno=node.lineno
900            )
901
902            if not self.in_def and len(self.identifiers.locally_assigned) > 0:
903                # if we are the "template" def, fudge locally
904                # declared/modified variables into the "__M_locals" dictionary,
905                # which is used for def calls within the same template,
906                # to simulate "enclosing scope"
907                self.printer.writeline(
908                    "__M_locals_builtin_stored = __M_locals_builtin()"
909                )
910                self.printer.writeline(
911                    "__M_locals.update(__M_dict_builtin([(__M_key,"
912                    " __M_locals_builtin_stored[__M_key]) for __M_key in"
913                    " [%s] if __M_key in __M_locals_builtin_stored]))"
914                    % ",".join([repr(x) for x in node.declared_identifiers()])
915                )
916
917    def visitIncludeTag(self, node):
918        self.printer.start_source(node.lineno)
919        args = node.attributes.get("args")
920        if args:
921            self.printer.writeline(
922                "runtime._include_file(context, %s, _template_uri, %s)"
923                % (node.parsed_attributes["file"], args)
924            )
925        else:
926            self.printer.writeline(
927                "runtime._include_file(context, %s, _template_uri)"
928                % (node.parsed_attributes["file"])
929            )
930
931    def visitNamespaceTag(self, node):
932        pass
933
934    def visitDefTag(self, node):
935        pass
936
937    def visitBlockTag(self, node):
938        if node.is_anonymous:
939            self.printer.writeline("%s()" % node.funcname)
940        else:
941            nameargs = node.get_argument_expressions(as_call=True)
942            nameargs += ["**pageargs"]
943            self.printer.writeline(
944                "if 'parent' not in context._data or "
945                "not hasattr(context._data['parent'], '%s'):" % node.funcname
946            )
947            self.printer.writeline(
948                "context['self'].%s(%s)" % (node.funcname, ",".join(nameargs))
949            )
950            self.printer.writeline("\n")
951
952    def visitCallNamespaceTag(self, node):
953        # TODO: we can put namespace-specific checks here, such
954        # as ensure the given namespace will be imported,
955        # pre-import the namespace, etc.
956        self.visitCallTag(node)
957
    def visitCallTag(self, node):
        """Emit the code for a ``<%call>`` tag.

        The generated code defines a ``ccall(caller)`` function that
        contains an inline def for every def/block tag directly inside
        the call tag plus a ``body()`` function for the tag's content,
        and returns the list of those callables.  It then installs
        ``runtime.Namespace('caller', ...)`` built from ``ccall`` as
        ``context.caller_stack.nextcaller``, writes the result of the
        call expression, and resets ``nextcaller`` to ``None`` in a
        ``finally:`` clause.
        """
        self.printer.writeline("def ccall(caller):")
        # names of the callables that ccall() returns; body() is always one
        export = ["body"]
        callable_identifiers = self.identifiers.branch(node, nested=True)
        body_identifiers = callable_identifiers.branch(node, nested=False)
        # we want the 'caller' passed to ccall to be used
        # for the body() function, but for other non-body()
        # <%def>s within <%call> we want the current caller
        # off the call stack (if any)
        body_identifiers.add_declared("caller")

        self.identifier_stack.append(body_identifiers)

        # helper visitor used for the first pass over the children: it
        # writes each def/block as an inline def and records its name
        class DefVisitor(object):
            def visitDefTag(s, node):
                s.visitDefOrBase(node)

            def visitBlockTag(s, node):
                s.visitDefOrBase(node)

            def visitDefOrBase(s, node):
                # "self" here is the enclosing render-method generator,
                # captured by closure; "s" is the DefVisitor instance
                self.write_inline_def(node, callable_identifiers, nested=False)
                if not node.is_anonymous:
                    export.append(node.funcname)
                # remove defs that are within the <%call> from the
                # "closuredefs" defined in the body, so they don't render
                # twice
                if node.funcname in body_identifiers.closuredefs:
                    del body_identifiers.closuredefs[node.funcname]

        vis = DefVisitor()
        for n in node.nodes:
            n.accept_visitor(vis)
        self.identifier_stack.pop()

        bodyargs = node.body_decl.get_argument_expressions()
        self.printer.writeline("def body(%s):" % ",".join(bodyargs))

        # TODO: figure out best way to specify
        # buffering/nonbuffering (at call time would be better)
        # NOTE: buffered is hard-coded to False, so the branch below
        # never runs as written
        buffered = False
        if buffered:
            self.printer.writelines("context._push_buffer()", "try:")
        self.write_variable_declares(body_identifiers)
        self.identifier_stack.append(body_identifiers)

        # second pass over the children: render the content of body()
        for n in node.nodes:
            n.accept_visitor(self)
        self.identifier_stack.pop()

        self.write_def_finish(node, buffered, False, False, callstack=False)
        # ccall() returns the exported callables
        self.printer.writelines(None, "return [%s]" % (",".join(export)), None)

        self.printer.writelines(
            # push on caller for nested call
            "context.caller_stack.nextcaller = "
            "runtime.Namespace('caller', context, "
            "callables=ccall(__M_caller))",
            "try:",
        )
        self.printer.start_source(node.lineno)
        # an empty filter list with is_expression=True is passed, so
        # any page-level/default filters are decided by
        # create_filter_callable
        self.printer.writelines(
            "__M_writer(%s)"
            % self.create_filter_callable([], node.expression, True),
            "finally:",
            "context.caller_stack.nextcaller = None",
            None,
        )
1025
1026
class _Identifiers(object):

    """tracks the status of identifier names as template code is rendered.

    An instance is itself a node visitor: constructing it with a
    ``node`` immediately visits that node and fills in the sets of
    declared / undeclared names and the dictionaries of top-level and
    closure defs described on the attributes below.
    """

    def __init__(self, compiler, node=None, parent=None, nested=False):
        """Build the identifier state for ``node``.

        :param compiler: compile context; its ``reserved_names`` set is
         checked against the names declared locally by ``node``.
        :param node: parse tree node to visit, or None for an empty
         state.
        :param parent: enclosing ``_Identifiers`` whose names are
         inherited, or None.
        :param nested: when true, names the parent recorded as
         undeclared are also treated as declared here.
        :raises exceptions.NameConflictError: if a reserved name is
         declared locally.
        """
        if parent is not None:
            # if we are the branch created in write_namespaces(),
            # we don't share any context from the main body().
            if isinstance(node, parsetree.NamespaceTag):
                self.declared = set()
                self.topleveldefs = util.SetLikeDict()
            else:
                # things that have already been declared
                # in an enclosing namespace (i.e. names we can just use)
                self.declared = (
                    set(parent.declared)
                    .union([c.name for c in parent.closuredefs.values()])
                    .union(parent.locally_declared)
                    .union(parent.argument_declared)
                )

                # if these identifiers correspond to a "nested"
                # scope, it means whatever the parent identifiers
                # had as undeclared will have been declared by that parent,
                # and therefore we have them in our scope.
                if nested:
                    self.declared = self.declared.union(parent.undeclared)

                # top level defs that are available
                self.topleveldefs = util.SetLikeDict(**parent.topleveldefs)
        else:
            self.declared = set()
            self.topleveldefs = util.SetLikeDict()

        self.compiler = compiler

        # things within this level that are referenced before they
        # are declared (e.g. assigned to)
        self.undeclared = set()

        # things that are declared locally.  some of these things
        # could be in the "undeclared" list as well if they are
        # referenced before declared
        self.locally_declared = set()

        # assignments made in explicit python blocks.
        # these will be propagated to
        # the context of local def calls.
        self.locally_assigned = set()

        # things that are declared in the argument
        # signature of the def callable
        self.argument_declared = set()

        # closure defs that are defined in this level
        self.closuredefs = util.SetLikeDict()

        self.node = node

        if node is not None:
            node.accept_visitor(self)

        illegal_names = self.compiler.reserved_names.intersection(
            self.locally_declared
        )
        if illegal_names:
            # sorted so that the message is deterministic when more than
            # one reserved name is involved
            raise exceptions.NameConflictError(
                "Reserved words declared in template: %s"
                % ", ".join(sorted(illegal_names))
            )

    def branch(self, node, **kwargs):
        """create a new Identifiers for a new Node, with
          this Identifiers as the parent."""

        return _Identifiers(self.compiler, node, self, **kwargs)

    @property
    def defs(self):
        """the set of all def nodes known here, top-level and closure."""
        return set(self.topleveldefs.union(self.closuredefs).values())

    def __repr__(self):
        return (
            "Identifiers(declared=%r, locally_declared=%r, "
            "undeclared=%r, topleveldefs=%r, closuredefs=%r, "
            "argumentdeclared=%r)"
            % (
                list(self.declared),
                list(self.locally_declared),
                list(self.undeclared),
                [c.name for c in self.topleveldefs.values()],
                [c.name for c in self.closuredefs.values()],
                self.argument_declared,
            )
        )

    def _add_undeclared(self, node):
        """record the identifiers ``node`` references that are neither
        declared in an enclosing scope nor declared locally so far.

        ``context`` is never recorded as undeclared.
        """
        # neither set is modified inside the loop, so look them up once
        # instead of building a union for every identifier
        declared = self.declared
        locally_declared = self.locally_declared
        for ident in node.undeclared_identifiers():
            if (
                ident != "context"
                and ident not in declared
                and ident not in locally_declared
            ):
                self.undeclared.add(ident)

    def check_declared(self, node):
        """update the state of this Identifiers with the undeclared
            and declared identifiers of the given node."""

        self._add_undeclared(node)
        self.locally_declared.update(node.declared_identifiers())

    def add_declared(self, ident):
        """mark ``ident`` as declared, removing it from the undeclared
        set if it was recorded there."""
        self.declared.add(ident)
        self.undeclared.discard(ident)

    def visitExpression(self, node):
        self.check_declared(node)

    def visitControlLine(self, node):
        self.check_declared(node)

    def visitCode(self, node):
        # module-level code blocks do not take part in identifier tracking
        if not node.ismodule:
            self.check_declared(node)
            self.locally_assigned = self.locally_assigned.union(
                node.declared_identifiers()
            )

    def visitNamespaceTag(self, node):
        # only traverse into the sub-elements of a
        # <%namespace> tag if we are the branch created in
        # write_namespaces()
        if self.node is node:
            for n in node.nodes:
                n.accept_visitor(self)

    def _check_name_exists(self, collection, node):
        """register ``node`` in ``collection`` under its funcname.

        :raises exceptions.CompileException: if a different node is
         already registered under that name and either of the two is a
         block.
        """
        existing = collection.get(node.funcname)
        collection[node.funcname] = node
        if (
            existing is not None
            and existing is not node
            and (node.is_block or existing.is_block)
        ):
            raise exceptions.CompileException(
                "%%def or %%block named '%s' already "
                "exists in this template." % node.funcname,
                **node.exception_kwargs
            )

    def visitDefTag(self, node):
        if node.is_root() and not node.is_anonymous:
            self._check_name_exists(self.topleveldefs, node)
        elif node is not self.node:
            self._check_name_exists(self.closuredefs, node)

        self._add_undeclared(node)

        # visit defs only one level deep
        if node is self.node:
            self.argument_declared.update(node.declared_identifiers())

            for n in node.nodes:
                n.accept_visitor(self)

    def visitBlockTag(self, node):
        # named blocks may not be nested inside of a def or a call tag
        if node is not self.node and not node.is_anonymous:

            if isinstance(self.node, parsetree.DefTag):
                raise exceptions.CompileException(
                    "Named block '%s' not allowed inside of def '%s'"
                    % (node.name, self.node.name),
                    **node.exception_kwargs
                )
            elif isinstance(
                self.node, (parsetree.CallTag, parsetree.CallNamespaceTag)
            ):
                raise exceptions.CompileException(
                    "Named block '%s' not allowed inside of <%%call> tag"
                    % (node.name,),
                    **node.exception_kwargs
                )

        self._add_undeclared(node)

        if not node.is_anonymous:
            self._check_name_exists(self.topleveldefs, node)
            self.undeclared.add(node.funcname)
        elif node is not self.node:
            self._check_name_exists(self.closuredefs, node)
        self.argument_declared.update(node.declared_identifiers())
        for n in node.nodes:
            n.accept_visitor(self)

    def visitTextTag(self, node):
        self._add_undeclared(node)

    def visitIncludeTag(self, node):
        self.check_declared(node)

    def visitPageTag(self, node):
        self.argument_declared.update(node.declared_identifiers())
        self.check_declared(node)

    def visitCallNamespaceTag(self, node):
        self.visitCallTag(node)

    def visitCallTag(self, node):
        # names referenced by the call tag are always recorded; its
        # arguments and children only count when this Identifiers was
        # created for the call tag itself
        self._add_undeclared(node)
        if node is self.node:
            self.argument_declared.update(node.declared_identifiers())
            for n in node.nodes:
                n.accept_visitor(self)
1264
1265
# Matches the header of a ``for`` control line.
#   group 1: the loop target(s) -- a single identifier or a
#            comma-separated list of identifiers, optionally wrapped in
#            parentheses
#   group 2: the iterable expression -- everything after ``in`` up to
#            the final colon matched
_FOR_LOOP = re.compile(
    r"^for\s+((?:\(?)\s*[A-Za-z_][A-Za-z_0-9]*"
    r"(?:\s*,\s*(?:[A-Za-z_][A-Za-z0-9_]*),??)*\s*(?:\)?))\s+in\s+(.*):"
)
1270
1271
def mangle_mako_loop(node, printer):
    """Return the ``for`` line to emit for ``node``, adding loop context
    set-up when the loop body references the ``loop`` variable.

    If no reference to ``loop`` is detected, the node's text is
    returned untouched.  Otherwise the last child of ``node`` is flagged
    with ``has_loop_context``, the lines that enter the loop context and
    open a ``try:`` block are written to ``printer``, and a rewritten
    ``for`` line that iterates over ``loop`` is returned.

    :raises SyntaxError: if the loop body references ``loop`` but the
     ``for`` line cannot be parsed by ``_FOR_LOOP``.
    """
    detector = LoopVariable()
    node.accept_visitor(detector)
    if not detector.detected:
        return node.text

    # flag the last child so that the matching "finally:" clause is
    # emitted when it is rendered
    node.nodes[-1].has_loop_context = True

    parsed = _FOR_LOOP.match(node.text)
    if not parsed:
        raise SyntaxError("Couldn't apply loop context: %s" % node.text)

    targets, iterable = parsed.group(1, 2)
    printer.writelines("loop = __M_loop._enter(%s)" % iterable, "try:")
    return "for %s in loop:" % targets
1293
1294
class LoopVariable(object):

    """Node visitor that records whether the name 'loop' shows up among
    the undeclared identifiers of the visited nodes."""

    def __init__(self):
        # becomes True once a reference to 'loop' has been seen
        self.detected = False

    def _loop_reference_detected(self, node):
        """Check ``node`` for a 'loop' reference, descending into its
        children only when the node itself has none."""
        if "loop" in node.undeclared_identifiers():
            self.detected = True
            return
        for child in node.get_children():
            child.accept_visitor(self)

    def visitControlLine(self, node):
        self._loop_reference_detected(node)

    def visitCode(self, node):
        self._loop_reference_detected(node)

    def visitExpression(self, node):
        self._loop_reference_detected(node)
1318